From 7e70da9fb7a33b3e60b6d9cab9b52ebe282a810b Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 8 Dec 2025 23:20:32 +0800 Subject: [PATCH] Revert "[Kernel] add custom moe ops for prefill" (#4806) Reverts vllm-project/vllm-ascend#4194 as it broke CI in https://github.com/vllm-project/vllm-ascend/actions/runs/20030369087/job/57437687382?pr=4791 Co-authored-by: wangxiyuan --- csrc/build_aclnn.sh | 16 +- csrc/dispatch_layout/op_host/CMakeLists.txt | 49 -- .../op_host/aclnn_dispatch_layout.cpp | 64 -- .../op_host/aclnn_dispatch_layout.h | 50 -- .../op_host/dispatch_layout.cpp | 51 -- .../op_host/dispatch_layout_tiling.cpp | 211 ------ .../op_kernel/dispatch_layout.cpp | 17 - .../op_kernel/dispatch_layout.h | 153 ----- .../op_kernel/dispatch_layout_tiling.h | 20 - .../moe_combine_normal/op_host/CMakeLists.txt | 49 -- .../op_host/aclnn_moe_combine_normal.cpp | 77 --- .../op_host/aclnn_moe_combine_normal.h | 62 -- .../op_host/moe_combine_normal.cpp | 71 -- .../op_host/moe_combine_normal_tiling.cpp | 546 --------------- .../op_kernel/moe_combine_normal.cpp | 22 - .../op_kernel/moe_combine_normal.h | 377 ----------- .../op_kernel/moe_combine_normal_tiling.h | 33 - .../op_host/CMakeLists.txt | 49 -- .../op_host/aclnn_moe_dispatch_normal.cpp | 84 --- .../op_host/aclnn_moe_dispatch_normal.h | 24 - .../op_host/moe_dispatch_normal.cpp | 92 --- .../op_host/moe_dispatch_normal_tiling.cpp | 635 ------------------ .../op_kernel/moe_dispatch_normal.cpp | 56 -- .../op_kernel/moe_dispatch_normal.h | 540 --------------- .../op_kernel/moe_dispatch_normal_tiling.h | 30 - csrc/notify_dispatch/op_host/CMakeLists.txt | 49 -- .../op_host/aclnn_notify_dispatch.cpp | 84 --- .../op_host/aclnn_notify_dispatch.h | 61 -- .../op_host/notify_dispatch.cpp | 60 -- .../op_host/notify_dispatch_tiling.cpp | 306 --------- .../op_kernel/notify_dispatch.cpp | 57 -- .../op_kernel/notify_dispatch.h | 495 -------------- .../op_kernel/notify_dispatch_tiling.h | 23 - csrc/torch_binding.cpp | 262 -------- csrc/utils.h | 24 - csrc/utils/inc/kernel/comm_args.h | 72 -- csrc/utils/inc/kernel/data_copy.h | 68 -- csrc/utils/inc/kernel/moe_distribute_base.h | 199 ------ csrc/utils/inc/kernel/sync_collectives.h | 426 ------------ 39 files changed, 2 insertions(+), 5562 deletions(-) delete mode 100644 csrc/dispatch_layout/op_host/CMakeLists.txt delete mode 100644 csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp delete mode 100644 csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h delete mode 100644 csrc/dispatch_layout/op_host/dispatch_layout.cpp delete mode 100644 csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp delete mode 100644 csrc/dispatch_layout/op_kernel/dispatch_layout.cpp delete mode 100644 csrc/dispatch_layout/op_kernel/dispatch_layout.h delete mode 100644 csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h delete mode 100644 csrc/moe_combine_normal/op_host/CMakeLists.txt delete mode 100644 csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp delete mode 100644 csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h delete mode 100644 csrc/moe_combine_normal/op_host/moe_combine_normal.cpp delete mode 100644 csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp delete mode 100644 csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp delete mode 100644 csrc/moe_combine_normal/op_kernel/moe_combine_normal.h delete mode 100644 csrc/moe_combine_normal/op_kernel/moe_combine_normal_tiling.h delete mode 100644 csrc/moe_dispatch_normal/op_host/CMakeLists.txt delete mode 100644 csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.cpp delete mode 100644 csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h delete mode 100644 csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp delete mode 100644 csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp delete mode 100644 csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp delete mode 100644 csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h delete mode 100644 csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal_tiling.h delete mode 100644 csrc/notify_dispatch/op_host/CMakeLists.txt delete mode 100644 csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp delete mode 100644 csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h delete mode 100644 csrc/notify_dispatch/op_host/notify_dispatch.cpp delete mode 100644 csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp delete mode 100644 csrc/notify_dispatch/op_kernel/notify_dispatch.cpp delete mode 100644 csrc/notify_dispatch/op_kernel/notify_dispatch.h delete mode 100644 csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h delete mode 100644 csrc/utils/inc/kernel/comm_args.h delete mode 100644 csrc/utils/inc/kernel/data_copy.h delete mode 100644 csrc/utils/inc/kernel/moe_distribute_base.h delete mode 100644 csrc/utils/inc/kernel/sync_collectives.h diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 426b014d..758856b7 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -45,19 +45,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE" sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE" - - CUSTOM_OPS_ARRAY=( - "grouped_matmul_swiglu_quant_weight_nz_tensor_list" - "lightning_indexer" - "sparse_flash_attention" - "dispatch_ffn_combine" - "dispatch_gmm_combine_decode" - "moe_combine_normal" - "moe_dispatch_normal" - "dispatch_layout" - "notify_dispatch" - ) - CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine;dispatch_gmm_combine_decode;" SOC_ARG="ascend910_93" else # others @@ -70,7 +58,7 @@ fi cd csrc rm -rf build output echo "building custom ops $CUSTOM_OPS for $SOC_VERSION" -bash build.sh -n "$CUSTOM_OPS" -c "$SOC_ARG" +bash build.sh -n $CUSTOM_OPS -c $SOC_ARG # install custom ops to vllm_ascend/_cann_ops_custom ./output/CANN-custom_ops*.run --install-path=$ROOT_DIR/vllm_ascend/_cann_ops_custom diff --git a/csrc/dispatch_layout/op_host/CMakeLists.txt b/csrc/dispatch_layout/op_host/CMakeLists.txt deleted file mode 100644 index 1176644e..00000000 --- a/csrc/dispatch_layout/op_host/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_ops_compile_options( - OP_NAME DispatchLayout - OPTIONS --cce-auto-sync=off - -Wno-deprecated-declarations - -Werror -) - -target_sources(op_host_aclnnInner PRIVATE - dispatch_layout.cpp -) - -target_sources(opapi PRIVATE - aclnn_dispatch_layout.cpp -) - -if (NOT BUILD_OPEN_PROJECT) - target_sources(aclnn_ops_train PRIVATE - aclnn_dispatch_layout.cpp - ) - - target_sources(aclnn_ops_infer PRIVATE - aclnn_dispatch_layout.cpp - ) -endif () - -target_sources(optiling PRIVATE - dispatch_layout_tiling.cpp -) - -target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_sources(opsproto PRIVATE) - -file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_layout.h") - -install(FILES ${_GMM_Aclnn_header} - DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL -) diff --git a/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp b/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp deleted file mode 100644 index f5e822f6..00000000 --- a/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include -#include "graph/types.h" -#include "aclnn_dispatch_layout.h" - -enum NnopbaseHcclServerType { - NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, - NNOPBASE_HCCL_SERVER_TYPE_MTE, - NNOPBASE_HCCL_SERVER_TYPE_END -}; -extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); - -#ifdef __cplusplus -extern "C" { -#endif - -extern aclnnStatus aclnnInnerDispatchLayoutGetWorkspaceSize( - const aclTensor *topkIdx, - int64_t numTokens, - int64_t numRanks, - int64_t numExperts, - int64_t numTopk, - const aclTensor *numTokensPerRank, - const aclTensor *numTokensPerExpert, - const aclTensor *isTokenInRank, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -extern aclnnStatus aclnnInnerDispatchLayout( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -aclnnStatus aclnnDispatchLayoutGetWorkspaceSize( - const aclTensor *topkIdx, - int64_t numTokens, - int64_t numRanks, - int64_t numExperts, - int64_t numTopk, - const aclTensor *numTokensPerRank, - const aclTensor *numTokensPerExpert, - const aclTensor *isTokenInRank, - uint64_t *workspaceSize, - aclOpExecutor **executor) -{ - return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, numTokensPerRank, - numTokensPerExpert, isTokenInRank, workspaceSize, executor); -} - -aclnnStatus aclnnDispatchLayout( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream) -{ - if (NnopbaseSetHcclServerType) { - NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); - } - return aclnnInnerDispatchLayout(workspace, workspaceSize, executor, stream); -} - -#ifdef __cplusplus -} -#endif diff --git a/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h b/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h deleted file mode 100644 index 20926bab..00000000 --- a/csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef ACLNN_DISPATCH_LAYOUT_H_ -#define ACLNN_DISPATCH_LAYOUT_H_ - -#include "aclnn/acl_meta.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/* funtion: aclnnDispatchLayoutGetWorkspaceSize - * topkIdx : required - * numTokens : required - * numRanks : required - * numExperts : required - * numTopk : required - * numTokensPerRank : required - * numTokensPerExpert : required - * isTokenInRank : required - * workspaceSize : size of workspace(output). - * executor : executor context(output). - */ -__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize( - const aclTensor *topkIdx, - int64_t numTokens, - int64_t numRanks, - int64_t numExperts, - int64_t numTopk, - const aclTensor *numTokensPerRank, - const aclTensor *numTokensPerExpert, - const aclTensor *isTokenInRank, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -/* funtion: aclnnDispatchLayout - * workspace : workspace memory addr(input). - * workspaceSize : size of workspace(input). - * executor : executor context(input). - * stream : acl stream. - */ -__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayout( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/csrc/dispatch_layout/op_host/dispatch_layout.cpp b/csrc/dispatch_layout/op_host/dispatch_layout.cpp deleted file mode 100644 index 5b09b38b..00000000 --- a/csrc/dispatch_layout/op_host/dispatch_layout.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "register/op_def_registry.h" - -namespace ops { -class DispatchLayout : public OpDef { -public: - explicit DispatchLayout(const char *name) : OpDef(name) - { - this->Input("topkIdx") - .ParamType(REQUIRED) - .DataType({ge::DT_INT64}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); - - this->Attr("num_tokens").Int(); - this->Attr("num_ranks").Int(); - this->Attr("num_experts").Int(); - this->Attr("num_topk").Int(); - - this->Output("numTokensPerRank") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); - this->Output("numTokensPerExpert") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); - this->Output("isTokenInRank") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); - - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("jitCompile.flag", "static_true") - .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); - - this->AICore().AddConfig("ascend910_93", aicore_config); - } -}; - -OP_ADD(DispatchLayout); -} // namespace ops diff --git a/csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp b/csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp deleted file mode 100644 index af24a4e9..00000000 --- a/csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp +++ /dev/null @@ -1,211 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "log/ops_log.h" -#include "graph/utils/type_utils.h" -#include "register/op_def_registry.h" -#include "../op_kernel/dispatch_layout_tiling.h" -#include "tiling/platform/platform_ascendc.h" -#include "tiling/hccl/hccl_tiling.h" -#include "experiment/platform/platform/platform_infos_def.h" - -using namespace ge; -namespace { -constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0; - -constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0; -constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1; -constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2; - -constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0; -constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1; -constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2; -constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3; -const int64_t MAX_COMM_WORLD_SIZE = 384; -const int64_t MAX_MOE_EXPERTS_NUM = 384; -constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; -constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024; -constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024; - -constexpr uint32_t TWO_DIMS = 2; -constexpr uint32_t K_MAX = 16; -} // namespace - -namespace optiling { -static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &tilingData) -{ - OPS_LOG_D(nodeName, "numToken is %u.", tilingData.dispatchLayoutInfo.numTokens); - OPS_LOG_D(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks); - OPS_LOG_D(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts); - OPS_LOG_D(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk); - OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize); -} - -static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, - DispatchLayoutTilingData &tilingData) -{ - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); - - auto numTokensPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOKENS_INDEX)); - auto numRanksPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_RANKS_INDEX)); - auto numExpertsPtr = attrs->GetAttrPointer(ATTR_NUM_EXPERTS_INDEX); - auto numTopkPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOPK_INDEX)); - - OPS_CHECK(numTokensPtr == nullptr, OPS_LOG_E(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(numRanksPtr == nullptr, OPS_LOG_E(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(numExpertsPtr == nullptr, OPS_LOG_E(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(numTopkPtr == nullptr, OPS_LOG_E(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED); - - OPS_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE), - OPS_LOG_E(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", MAX_COMM_WORLD_SIZE, *numRanksPtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*numExpertsPtr <= 0) || (*numExpertsPtr > MAX_MOE_EXPERTS_NUM), - OPS_LOG_E(nodeName, "numExperts is invalid, only support (0, %ld], but got numExperts=%ld.", MAX_MOE_EXPERTS_NUM, *numExpertsPtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*numTopkPtr <= 0) || (*numTopkPtr > K_MAX), - OPS_LOG_E(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr), - return ge::GRAPH_FAILED); - - tilingData.dispatchLayoutInfo.numTokens = static_cast(*numTokensPtr); - tilingData.dispatchLayoutInfo.numRanks = static_cast(*numRanksPtr); - tilingData.dispatchLayoutInfo.numExperts = static_cast(*numExpertsPtr); - tilingData.dispatchLayoutInfo.numTopk = static_cast(*numTopkPtr); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) -{ - size_t *workSpaces = context->GetWorkspaceSizes(1); - OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); - workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE; - return ge::GRAPH_SUCCESS; -} - -static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName) -{ - auto topkIdx = context->GetInputDesc(INPUT_TOPK_IDX_INDEX); - auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX); - auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX); - auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX); - - OPS_CHECK(topkIdx == nullptr, OPS_LOG_E(nodeName, "topkIdx is null."), return false); - OPS_CHECK(numTokensPerRank == nullptr, OPS_LOG_E(nodeName, "numTokensPerRank is null."), return false); - OPS_CHECK(numTokensPerExpert == nullptr, OPS_LOG_E(nodeName, "numTokensPerExpert is null."), return false); - OPS_CHECK(isTokenInRank == nullptr, OPS_LOG_E(nodeName, "isTokenInRank is null."), return false); - - OPS_CHECK((topkIdx->GetDataType() != ge::DT_INT64), - OPS_LOG_E(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.", - static_cast(topkIdx->GetDataType())), return false); - OPS_CHECK((numTokensPerRank->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, "numTokensPerRank datatype is invalid, datatype should be int, but is %d.", - static_cast(numTokensPerRank->GetDataType())), return false); - OPS_CHECK((numTokensPerExpert->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, "numTokensPerExpert datatype is invalid, datatype should be int, but is %d.", - static_cast(numTokensPerExpert->GetDataType())), return false); - OPS_CHECK((isTokenInRank->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.", - static_cast(isTokenInRank->GetDataType())), return false); - - return true; -} - -static bool CheckTensorShape(gert::TilingContext *context, const char *nodeName) -{ - const gert::StorageShape *topkIdxStorageShape = context->GetInputShape(INPUT_TOPK_IDX_INDEX); - int64_t topkIdxDim0 = topkIdxStorageShape->GetStorageShape().GetDim(0); - int64_t topkIdxDim1 = topkIdxStorageShape->GetStorageShape().GetDim(1); - - OPS_CHECK((topkIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS), - OPS_LOG_E(nodeName, "topkIdx must be 2-dimension, but get %lu dim.", - topkIdxStorageShape->GetStorageShape().GetDimNum()), return false); - - return true; -} - -static ge::graphStatus TilingCheckTensor( - gert::TilingContext *context, const char *nodeName) -{ - OPS_CHECK(!CheckTensorDataType(context, nodeName), - OPS_LOG_E(nodeName, "params dataType is invalid."), - return ge::GRAPH_FAILED); - - OPS_CHECK(!CheckTensorShape(context, nodeName), - OPS_LOG_E(nodeName, "params dataType is invalid."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context) -{ - const char *nodeName = context->GetNodeName(); - DispatchLayoutTilingData *tilingData = context->GetTilingData(); - OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); - OPS_LOG_I(nodeName, "Enter NotifyDispatch tiling check func."); - - OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), - return ge::GRAPH_FAILED); - - OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling check param failed."), - return ge::GRAPH_FAILED); - - OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling set workspace failed."), - return ge::GRAPH_FAILED); - - fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); - fe::PlatFormInfos &platformInfo = *platformInfoPtr; - - std::string socVersion; - (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); - - auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - uint32_t blockDim; - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint64_t ubSize = 0UL; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); - - blockDim = aivNum; - context->SetBlockDim(blockDim); - tilingData->dispatchLayoutInfo.totalUbSize = ubSize; - OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize); - PrintTilingDataInfo(nodeName, *tilingData); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus DispatchLayoutTilingFunc(gert::TilingContext *context) -{ - fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); - fe::PlatFormInfos &platformInfo = *platformInfoPtr; - - std::string socVersion; - ge::graphStatus ret; - ret = DispatchLayoutTilingFuncImpl(context); - return ret; -} - -struct DispatchLayoutCompileInfo {}; -ge::graphStatus TilingParseForDispatchLayout(gert::TilingParseContext *context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_OPTILING(DispatchLayout) - .Tiling(DispatchLayoutTilingFunc) - .TilingParse(TilingParseForDispatchLayout); -} // namespace optiling diff --git a/csrc/dispatch_layout/op_kernel/dispatch_layout.cpp b/csrc/dispatch_layout/op_kernel/dispatch_layout.cpp deleted file mode 100644 index 13e24a13..00000000 --- a/csrc/dispatch_layout/op_kernel/dispatch_layout.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "kernel_operator.h" -#include "dispatch_layout.h" -#include "dispatch_layout_tiling.h" - - -extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, - GM_ADDR isTokenInRank, GM_ADDR workspace, GM_ADDR tiling) -{ - REGISTER_TILING_DEFAULT(DispatchLayoutTilingData); - GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling); - - TPipe pipe; - - DispatchLayout op; - op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, workspace, &pipe, &tilingData); - op.Process(); -} diff --git a/csrc/dispatch_layout/op_kernel/dispatch_layout.h b/csrc/dispatch_layout/op_kernel/dispatch_layout.h deleted file mode 100644 index ba261e7e..00000000 --- a/csrc/dispatch_layout/op_kernel/dispatch_layout.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef DISPATCH_LAYOUT_H -#define DISPATCH_LAYOUT_H - -#include -#include "kernel_operator.h" - -#include "../common/comm_args.h" -#include "../common/data_copy.h" -#include "../common/sync_collectives.h" -#include "../common/moe_distribute_base.h" -#include "dispatch_layout_tiling.h" - -using namespace AscendC; -using namespace Moe; - -constexpr uint32_t UB_32_ALIGN = 32U; -constexpr uint32_t AIV_NUM = 48; - -template -__aicore__ inline void SyncFunc() -{ - int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); - AscendC::SetFlag(eventID); - AscendC::WaitFlag(eventID); -} - -template -class DispatchLayout { - -public: - __aicore__ inline DispatchLayout() {}; - - __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank, - GM_ADDR workspace, TPipe *pipe, const DispatchLayoutTilingData *tilingData) - { - numTokens_ = tilingData->dispatchLayoutInfo.numTokens; - numRanks_ = tilingData->dispatchLayoutInfo.numRanks; - numExperts_ = tilingData->dispatchLayoutInfo.numExperts; - numTopk_ = tilingData->dispatchLayoutInfo.numTopk; - tpipe_ = pipe; - - coreIdx_ = GetBlockIdx(); - uint32_t temp = numTokens_ / AIV_NUM; - uint32_t restNum = numTokens_ % AIV_NUM; - int64_t topkIdxOffset; - int64_t isTokenOffset; - tempTokens_ = temp; - if (coreIdx_ < restNum) { - tempTokens_++; - } - topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; - numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; - numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; - isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; - - if (coreIdx_ < restNum) { - topkIdxOffset = coreIdx_ * topkIdx32AlignIntLen_; - isTokenOffset = coreIdx_ * isTokenInRank32AlignIntLen_; - } else { - topkIdxOffset = restNum * Ceil((tempTokens_ + 1) * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN - + (coreIdx_ - restNum) * topkIdx32AlignIntLen_; - isTokenOffset = restNum * Ceil((tempTokens_ + 1) * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN - + (coreIdx_ - restNum) * isTokenInRank32AlignIntLen_; - } - - topkIdxGM_.SetGlobalBuffer((__gm__ int64_t*)(topkIdx + topkIdxOffset)); - numTokensPerRankGM_.SetGlobalBuffer((__gm__ T*)numTokensPerRank); - numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T*)numTokensPerExpert); - isTokenInRankGM_.SetGlobalBuffer((__gm__ T*)(isTokenInRank + isTokenOffset)); - - - } - - __aicore__ inline void Process() - { - tpipe_->Reset(); - tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_); - tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_); - tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_); - tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_); - tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T)); - - LocalTensor topkIdxTensor = topkIdxBuf_.AllocTensor(); - const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; - const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; - DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); - SyncFunc(); - - LocalTensor numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor(); - LocalTensor numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor(); - LocalTensor isTokenInRankTensor = isTokenInRankBuf_.AllocTensor(); - LocalTensor seenRankTensor = seenRankBuf_.AllocTensor(); - Duplicate(numTokensPerRankTensor, 0, numRanks_); - Duplicate(numTokensPerExpertTensor, 0, numExperts_); - Duplicate(isTokenInRankTensor, 0, tempTokens_ * numRanks_); - SyncFunc(); - - int experts_per_rank = numExperts_ / numRanks_; - for (int i = 0; i < tempTokens_; ++i) { - SyncFunc(); - Duplicate(seenRankTensor, 0, numRanks_); - SyncFunc(); - for (int j = 0; j < numTopk_; ++j) { - int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); - uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; - numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); - int rank_id = expert_idx / experts_per_rank; - if (!seenRankTensor.GetValue(rank_id)) { - uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1; - isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); - seenRankTensor.SetValue(rank_id, 1); - numTokensPerRankTensor.SetValue(rank_id, per_rank_num); - } - } - } - - const DataCopyExtParams isTokenInRankDataCopyParams{1U, isTokenInRank32AlignIntLen_, 0U, 0U, 0U}; - DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); - AscendC::SetAtomicAdd(); - const DataCopyExtParams numTokensPerRankDataCopyParams{1U, numTokensPerRank32AlignIntLen_, 0U, 0U, 0U}; - DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); - const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; - DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); - AscendC::SetAtomicNone(); - } - -private: - GlobalTensor topkIdxGM_; - GlobalTensor numTokensPerRankGM_; - GlobalTensor numTokensPerExpertGM_; - GlobalTensor isTokenInRankGM_; - - TBuf<> topkIdxBuf_; - TBuf<> numTokensPerRankBuf_; - TBuf<> numTokensPerExpertBuf_; - TBuf<> isTokenInRankBuf_; - TBuf<> seenRankBuf_; - - TPipe *tpipe_{nullptr}; - uint32_t numTokens_{0}; - uint32_t numRanks_{0}; - uint32_t numExperts_{0}; - uint32_t numTopk_{0}; - uint32_t coreIdx_{0}; - uint32_t tempTokens_{0}; - - uint32_t topkIdx32AlignIntLen_{0}; - uint32_t numTokensPerRank32AlignIntLen_{0}; - uint32_t numTokensPerExpert32AlignIntLen_{0}; - uint32_t isTokenInRank32AlignIntLen_{0}; -}; - -#endif // DISPATCH_LAYOUT_H diff --git a/csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h b/csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h deleted file mode 100644 index bf56f45a..00000000 --- a/csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef DISPATCH_LAYOUT_TILING_H -#define DISPATCH_LAYOUT_TILING_H - -#include "kernel_tiling/kernel_tiling.h" - -struct DispatchLayoutInfo { - uint32_t numTokens; - uint32_t numRanks; - uint32_t numExperts; - uint32_t numTopk; - uint64_t totalUbSize; -}; - -struct DispatchLayoutTilingData { - Mc2InitTiling mc2InitTiling; - Mc2CcTiling mc2CcTiling1; - DispatchLayoutInfo dispatchLayoutInfo; -}; - -#endif diff --git a/csrc/moe_combine_normal/op_host/CMakeLists.txt b/csrc/moe_combine_normal/op_host/CMakeLists.txt deleted file mode 100644 index 190adfe1..00000000 --- a/csrc/moe_combine_normal/op_host/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_ops_compile_options( - OP_NAME MoeCombineNormal - OPTIONS --cce-auto-sync=off - -Wno-deprecated-declarations - -Werror -) - -target_sources(op_host_aclnnInner PRIVATE - moe_combine_normal.cpp -) - -target_sources(opapi PRIVATE - aclnn_moe_combine_normal.cpp -) - -if (NOT BUILD_OPEN_PROJECT) - target_sources(aclnn_ops_train PRIVATE - aclnn_moe_combine_normal.cpp - ) - - target_sources(aclnn_ops_infer PRIVATE - aclnn_moe_combine_normal.cpp - ) -endif () - -target_sources(optiling PRIVATE - moe_combine_normal_tiling.cpp -) - -target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_sources(opsproto PRIVATE) - -file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_combine_normal.h") - -install(FILES ${_GMM_Aclnn_header} - DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL -) diff --git a/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp b/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp deleted file mode 100644 index 3b70f958..00000000 --- a/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include -#include "graph/types.h" -#include "aclnn_moe_combine_normal.h" - -enum NnopbaseHcclServerType { - NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, - NNOPBASE_HCCL_SERVER_TYPE_MTE, - NNOPBASE_HCCL_SERVER_TYPE_END -}; -extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); - -#ifdef __cplusplus -extern "C" { -#endif -extern aclnnStatus aclnnInnerMoeCombineNormalGetWorkspaceSize( - const aclTensor *recvX, - const aclTensor *tokenSrcInfo, - const aclTensor *epRecvCounts, - const aclTensor *recvTopkWeights, - const aclTensor *tpRecvCountsOptional, - char *epGroupName, - int64_t epWorldSize, - int64_t epRankId, - char *tpGroupNameOptional, - int64_t tpWorldSize, - int64_t tpRankId, - int64_t moeExpertNum, - int64_t globalBs, - const aclTensor *out, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -extern aclnnStatus aclnnInnerMoeCombineNormal( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize( - const aclTensor *recvX, - const aclTensor *tokenSrcInfo, - const aclTensor *epRecvCounts, - const aclTensor *recvTopkWeights, - const aclTensor *tpRecvCountsOptional, - char *epGroupName, - int64_t epWorldSize, - int64_t epRankId, - char *tpGroupNameOptional, - int64_t tpWorldSize, - int64_t tpRankId, - int64_t moeExpertNum, - int64_t globalBs, - const aclTensor *out, - uint64_t *workspaceSize, - aclOpExecutor **executor) -{ - return aclnnInnerMoeCombineNormalGetWorkspaceSize(recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights, - tpRecvCountsOptional, epGroupName, epWorldSize, epRankId, - tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum, - globalBs, out, workspaceSize, executor); -} - -aclnnStatus aclnnMoeCombineNormal( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream) -{ - if (NnopbaseSetHcclServerType) { - NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); - } - return aclnnInnerMoeCombineNormal(workspace, workspaceSize, executor, stream); -} - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h b/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h deleted file mode 100644 index 50ba7122..00000000 --- a/csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef ACLNN_MOE_COMBINE_NORMAL_H_ -#define ACLNN_MOE_COMBINE_NORMAL_H_ - -#include "aclnn/acl_meta.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/* funtion: aclnnMoeCombineGetWorkspaceSize - * recvX : required - * tokenSrcInfo : required - * epRecvCounts : required - * recvTopkWeights : required - * tpRecvCountsOptional : required - * epGroupName : optional - * epWorldSize : required - * epRankId : required - * tpGroupNameOptional : required - * tpWorldSize : optional - * tpRankId : optional - * moeExpertNum : optional - * globalBs : optional - * out : required - * workspaceSize : size of workspace(output). - * executor : executor context(output). - */ -__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize( - const aclTensor *recvX, - const aclTensor *tokenSrcInfo, - const aclTensor *epRecvCounts, - const aclTensor *recvTopkWeights, - const aclTensor *tpRecvCountsOptional, - char *epGroupName, - int64_t epWorldSize, - int64_t epRankId, - char *tpGroupNameOptional, - int64_t tpWorldSize, - int64_t tpRankId, - int64_t moeExpertNum, - int64_t globalBs, - const aclTensor *out, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -/* funtion: aclnnMoeCombine - * workspace : workspace memory addr(input). - * workspaceSize : size of workspace(input). - * executor : executor context(input). - * stream : acl stream. - */ -__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormal( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -#ifdef __cplusplus -} -#endif - -#endif \ No newline at end of file diff --git a/csrc/moe_combine_normal/op_host/moe_combine_normal.cpp b/csrc/moe_combine_normal/op_host/moe_combine_normal.cpp deleted file mode 100644 index 072ee43f..00000000 --- a/csrc/moe_combine_normal/op_host/moe_combine_normal.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include "register/op_def_registry.h" - -namespace ops { -class MoeCombineNormal : public OpDef { -public: - explicit MoeCombineNormal(const char* name) : OpDef(name) { - this->Input("recv_x") - .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("token_src_info") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("ep_recv_counts") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("recv_topk_weights") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("tp_recv_counts") - .ParamType(OPTIONAL) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - - this->Output("x") - .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Attr("ep_group_name").AttrType(REQUIRED).String(); - this->Attr("ep_world_size").AttrType(REQUIRED).Int(); - this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); - this->Attr("tp_group_name").AttrType(OPTIONAL).String(""); - this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0); - this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0); - this->Attr("moe_expert_num").AttrType(REQUIRED).Int(); - this->Attr("global_bs").AttrType(OPTIONAL).Int(0); - - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("jitCompile.flag", "static_true") - .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); - - this->AICore().AddConfig("ascend910_93", aicore_config); - this->MC2().HcclGroup({"ep_group_name", "tp_group_name"}); - } -}; - -OP_ADD(MoeCombineNormal); - -} // namespace ops \ No newline at end of file diff --git a/csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp b/csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp deleted file mode 100644 index 66b3ab3d..00000000 --- a/csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp +++ /dev/null @@ -1,546 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "register/tilingdata_base.h" -#include "tiling/tiling_api.h" -#include "log/ops_log.h" -#include "graph/utils/type_utils.h" -#include "register/op_def_registry.h" -#include "../op_kernel/moe_combine_normal_tiling.h" - -using namespace AscendC; -using namespace ge; - -namespace { - class Mc2TilingUtils { - public: - #define HCCL_BUFFSIZE "HCCL_BUFFSIZE" - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - if (getenv(HCCL_BUFFSIZE) == nullptr) { - OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(HCCL_BUFFSIZE)); - defaultWindowSize = std::stoi(envStr); - } catch (...) { - OPS_LOG_E("", "Unknown Exception encountered when parser env HCCL_BUFFERSIZE"); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } - }; - constexpr uint32_t RECV_X_INDEX = 0; - constexpr uint32_t TOKEN_SRC_INFO_INDEX = 1; - constexpr uint32_t EP_RECV_COUNTS_INDEX = 2; - constexpr uint32_t TOPK_WEIGHTS_INDEX = 3; - constexpr uint32_t TP_RECV_COUNTS_INDEX = 4; - constexpr uint32_t OUTPUT_X_INDEX = 0; - - constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; - constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; - constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; - constexpr uint32_t ATTR_GROUP_TP_INDEX = 3; - constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4; - constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5; - constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6; - constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7; - - constexpr uint32_t TWO_DIMS = 2U; - constexpr uint32_t ONE_DIM = 1U; - constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll - constexpr uint32_t OP_TYPE_REDUCE_SCATTER = 7U; // numeric representation of ReduceScatter - - constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL; - constexpr int64_t MAX_EP_WORLD_SIZE = 384; - constexpr int64_t MIN_EP_WORLD_SIZE = 2; - constexpr int64_t MAX_TP_WORLD_SIZE = 2; - constexpr int64_t BS_UPPER_BOUND = 8000; - - constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; - constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes - constexpr int64_t MOE_EXPERT_MAX_NUM = 512; - constexpr int64_t K_MAX = 16; - constexpr int64_t H_MIN = 1024; - constexpr int64_t H_MAX = 7168; - constexpr uint64_t MB_SIZE = 1024UL * 1024UL; - constexpr uint64_t TRIPLE = 3; - constexpr uint64_t WIN_ADDR_ALIGN = 512UL; - constexpr uint64_t SCALE_RECV_IDX_BUFFER = 44UL; // scale32B + 3*4 src info - constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; - constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL; - constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL; - constexpr uint64_t UB_ALIGN = 32UL; - constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL; - - enum class CommQuantMode : int32_t { - NON_QUANT = 0, - INT12_QUANT = 1, - INT8_QUANT = 2 - }; - using CommQuantModeType = std::underlying_type; -} - -namespace optiling { - -// Specific to A3 -static void PrintTilingDataInfo(const char *nodeName, MoeCombineNormalTilingData& tilingData) -{ - OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeCombineNormalInfo.epWorldSize); - OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeCombineNormalInfo.tpWorldSize); - OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeCombineNormalInfo.epRankId); - OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeCombineNormalInfo.tpRankId); - OPS_LOG_D(nodeName, "expertShardType is %u.", tilingData.moeCombineNormalInfo.expertShardType); - OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeCombineNormalInfo.moeExpertNum); - OPS_LOG_D(nodeName, "moeExpertPerRankNum is %u.", tilingData.moeCombineNormalInfo.moeExpertPerRankNum); - OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeCombineNormalInfo.globalBs); - OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeCombineNormalInfo.bs); - OPS_LOG_D(nodeName, "k is %u.", tilingData.moeCombineNormalInfo.k); - OPS_LOG_D(nodeName, "h is %u.", tilingData.moeCombineNormalInfo.h); - OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeCombineNormalInfo.aivNum); - OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeCombineNormalInfo.totalUbSize); - OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeCombineNormalInfo.totalWinSize); -} - -static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData, - const char *nodeName, std::string &groupEp, std::string &groupTp) -{ - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return ge::GRAPH_FAILED); - - auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); - auto groupTpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_TP_INDEX)); - auto epWorldSizePtr = attrs->GetAttrPointer(ATTR_EP_WORLD_SIZE_INDEX); - auto tpWorldSizePtr = attrs->GetAttrPointer(ATTR_TP_WORLD_SIZE_INDEX); - auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); - auto tpRankIdPtr = attrs->GetAttrPointer(ATTR_TP_RANK_ID_INDEX); - auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); - - // Check for null - OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) || - (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), OPS_LOG_E(nodeName, "groupEp is invalid."), - return ge::GRAPH_FAILED); - OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSize is null."), return ge::GRAPH_FAILED); - OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSize is null."), return ge::GRAPH_FAILED); - OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankId is null."), return ge::GRAPH_FAILED); - OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankId is null."), return ge::GRAPH_FAILED); - OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNum is null."), return ge::GRAPH_FAILED); - - // Check if it meets uint32_t and other constraints - int64_t moeExpertNum = *moeExpertNumPtr; - int64_t epWorldSize = *epWorldSizePtr; - OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE), - OPS_LOG_E(nodeName, "epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.", - MIN_EP_WORLD_SIZE, MAX_EP_WORLD_SIZE, epWorldSize), return ge::GRAPH_FAILED); - OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE), - OPS_LOG_E(nodeName, "tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.", - MAX_TP_WORLD_SIZE, *tpWorldSizePtr), return ge::GRAPH_FAILED); - OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize), - OPS_LOG_E(nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.", - epWorldSize, *epRankIdPtr), return ge::GRAPH_FAILED); - - if (*tpWorldSizePtr > 1) { - OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr), - OPS_LOG_E(nodeName, "tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.", - *tpWorldSizePtr, *tpRankIdPtr), return ge::GRAPH_FAILED); - OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) || - (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), - OPS_LOG_E(nodeName, "groupTpPtr is null."), return ge::GRAPH_FAILED); - groupTp = std::string(groupTpPtr); - } else { - OPS_CHECK(*tpRankIdPtr != 0, - OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr), - return ge::GRAPH_FAILED); - } - OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM), - OPS_LOG_E(nodeName, "moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.", - MOE_EXPERT_MAX_NUM, moeExpertNum), return ge::GRAPH_FAILED); - int64_t moePerRankNum = moeExpertNum / epWorldSize; - int64_t curDispatchStatusNum = moePerRankNum * epWorldSize; - OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM), - OPS_LOG_E(nodeName, "The moe experts num must meet the conditions," - " (moeExpertNum / epWorldSize) * epWorldSize <= 1280, but cur is %ld.", - curDispatchStatusNum), return ge::GRAPH_FAILED); - - groupEp = std::string(groupEpPtr); - tilingData.moeCombineNormalInfo.epWorldSize = static_cast(epWorldSize); - tilingData.moeCombineNormalInfo.tpWorldSize = static_cast(*tpWorldSizePtr); - tilingData.moeCombineNormalInfo.epRankId = static_cast(*epRankIdPtr); - tilingData.moeCombineNormalInfo.tpRankId = static_cast(*tpRankIdPtr); - tilingData.moeCombineNormalInfo.moeExpertNum = static_cast(moeExpertNum); - - return ge::GRAPH_SUCCESS; -} - -static bool CheckInputTensorDim(gert::TilingContext *context, const char *nodeName) -{ - const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX); - OPS_CHECK(recvXStorageShape == nullptr, OPS_LOG_E(nodeName, "recvX is null."), return false); - OPS_CHECK(recvXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "recvX must be 2-dimension, but got %lu dim", - recvXStorageShape->GetStorageShape().GetDimNum()), return false); - OPS_LOG_D(nodeName, "recvX dim0 = %ld", recvXStorageShape->GetStorageShape().GetDim(0)); - OPS_LOG_D(nodeName, "recvX dim1 = %ld", recvXStorageShape->GetStorageShape().GetDim(1)); - - const gert::StorageShape *tokenSrcInfoStorageShape = context->GetInputShape(TOKEN_SRC_INFO_INDEX); - OPS_CHECK(tokenSrcInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoForCombine is null."), return false); - OPS_CHECK(tokenSrcInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, - OPS_LOG_E(nodeName, "tokenSrcInfoForCombine must be 1-dimension, but got %lu dim", - tokenSrcInfoStorageShape->GetStorageShape().GetDimNum()), return false); - OPS_LOG_D(nodeName, "tokenSrcInfoForCombine dim0 = %ld", tokenSrcInfoStorageShape->GetStorageShape().GetDim(0)); - - const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX); - OPS_CHECK(topkWeightsStorageShape == nullptr, OPS_LOG_E(nodeName, "topkWeights is null."), return false); - OPS_CHECK(topkWeightsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "topkWeights must be 2-dimension, but got %lu dim", - topkWeightsStorageShape->GetStorageShape().GetDimNum()), return false); - OPS_LOG_D(nodeName, "topkWeights dim0 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(0)); - OPS_LOG_D(nodeName, "topkWeights dim1 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(1)); - - return true; -} - -static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char *nodeName) -{ - const gert::StorageShape *tpRecvCountsStorageShape = context->GetOptionalInputShape(TP_RECV_COUNTS_INDEX); - OPS_CHECK(tpRecvCountsStorageShape == nullptr, OPS_LOG_E(nodeName, "tpRecvCounts is null."), return false); - OPS_CHECK(tpRecvCountsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, - OPS_LOG_E(nodeName, "tpRecvCounts must be 1-dimension, but got %lu dim", - tpRecvCountsStorageShape->GetStorageShape().GetDimNum()), return false); - OPS_LOG_D(nodeName, "tpRecvCounts dim0 = %ld", tpRecvCountsStorageShape->GetStorageShape().GetDim(0)); - - return true; -} - -static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName) -{ - const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX); - OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x is null."), return false); - OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "x must be 2-dimension, but got %lu dim", xStorageShape->GetStorageShape().GetDimNum()), - return false); - OPS_LOG_D(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0)); - OPS_LOG_D(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1)); - - return true; -} - -static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName) -{ - OPS_CHECK(!CheckInputTensorDim(context, nodeName), - OPS_LOG_E(nodeName, "param shape of input tensor is invalid"), return false); - - OPS_CHECK(!CheckOptionalInputTensorDim(context, nodeName), - OPS_LOG_E(nodeName, "param shape of optional input tensor is invalid"), return false); - - OPS_CHECK(!CheckOutputTensorDim(context, nodeName), - OPS_LOG_E(nodeName, "param shape of output tensor is invalid"), return false); - - return true; -} - -// Validate data type -static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName) -{ - auto recvXDesc = context->GetInputDesc(RECV_X_INDEX); - OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false); - OPS_CHECK((recvXDesc->GetDataType() != ge::DT_BF16) && (recvXDesc->GetDataType() != ge::DT_FLOAT16), - OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is " - ), return false); - auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX); - OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false); - OPS_CHECK((tokenSrcInfoDesc->GetDataType() != ge::DT_INT32), OPS_LOG_E(nodeName, "tokenSrcInfoForCombine dataType is invalid," - " dataType should be int32, but is"), return false); - auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX); - OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false); - OPS_CHECK((tpRecvCountsDesc->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, "tpRecvCounts dataType is invalid, dataType should be int32, but is "), return false); - auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX); - OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false); - OPS_CHECK((topkWeightsDesc->GetDataType() != ge::DT_FLOAT), - OPS_LOG_E(nodeName, "topkWeights dataType is invalid, dataType should be float, but is "), - return false); - auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX); - OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false); - OPS_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()), OPS_LOG_E(nodeName, - "x dataType is invalid, dataType should be equal to recvX dataType , but is "), - return false); - return true; -} - -static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName) -{ - auto recvXDesc = context->GetInputDesc(RECV_X_INDEX); - OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(recvXDesc->GetStorageFormat())) == - ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "recvXFormat is invalid"), return false); - - auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX); - OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(tokenSrcInfoDesc->GetStorageFormat())) == - ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tokenSrcInfoFormat is invalid"), return false); - - auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX); - OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(tpRecvCountsDesc->GetStorageFormat())) == - ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tpRecvCountsFormat is invalid"), return false); - - auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX); - OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(topkWeightsDesc->GetStorageFormat())) == - ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "topkWeightsFormat is invalid"), return false); - - auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX); - OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "xFormat is invalid"), return false); - - return true; -} - -static bool CheckTensorShape(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData, - const char *nodeName, uint32_t localExpertNum) -{ - const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX); - int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0); - int64_t topkWeightsDim1 = topkWeightsStorageShape->GetStorageShape().GetDim(1); - int64_t moeExpertNum = static_cast(tilingData.moeCombineNormalInfo.moeExpertNum); - OPS_CHECK((topkWeightsDim1 <= 0) || (topkWeightsDim1 > K_MAX || (topkWeightsDim1 > moeExpertNum)), - OPS_LOG_E(nodeName, "topkWeights's dim1(K) should be in (0, min(%ld, moeExpertNum %ld)], " - "but got topkWeights's dim1=%ld.", K_MAX, moeExpertNum, topkWeightsDim1), return false); - tilingData.moeCombineNormalInfo.k = static_cast(topkWeightsDim1); - - // Validate recvX dimensions and set h - int64_t tpWorldSize = static_cast(tilingData.moeCombineNormalInfo.tpWorldSize); - const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX); - int64_t recvXDim1 = recvXStorageShape->GetStorageShape().GetDim(1); - OPS_CHECK((recvXDim1 < H_MIN) || (recvXDim1 > H_MAX), - OPS_LOG_E(nodeName, "recvX's dim1(H) should be in [%ld, %ld], but got %ld.", - H_MIN, H_MAX, recvXDim1), return false); // 32-byte aligned - tilingData.moeCombineNormalInfo.h = static_cast(recvXDim1); - - // Validate epRecvCount and tpRecvCount dimensions - int64_t epWorldSize = static_cast(tilingData.moeCombineNormalInfo.epWorldSize); - int64_t moeExpertPerRankNum = static_cast(tilingData.moeCombineNormalInfo.moeExpertPerRankNum); - - // Validate x dimensions - const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX); - int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); - int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); - OPS_CHECK(xDim0 != topkWeightsDim0, OPS_LOG_E(nodeName, - "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false); - OPS_CHECK(xDim1 != recvXDim1, OPS_LOG_E(nodeName, - "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), return false); - - return true; -} - -static bool CheckAttrs(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData, - const char *nodeName, uint32_t &localMoeExpertNum) -{ - uint32_t epWorldSize = tilingData.moeCombineNormalInfo.epWorldSize; - uint32_t tpWorldSize = tilingData.moeCombineNormalInfo.tpWorldSize; - uint32_t moeExpertNum = tilingData.moeCombineNormalInfo.moeExpertNum; - - // Validate if moe expert number can be evenly distributed across multiple machines - OPS_CHECK(moeExpertNum % epWorldSize != 0, - OPS_LOG_E(nodeName, "moeExpertNum should be divisible by epWorldSize, " - "but got moeExpertNum=%d, epWorldSize=%d.", moeExpertNum, epWorldSize), return false); - localMoeExpertNum = moeExpertNum / epWorldSize; - OPS_CHECK(localMoeExpertNum <= 0, - OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum), return false); - // Validate if expert number per card equals 1 when tp=2 - OPS_CHECK((localMoeExpertNum > 1) && (tpWorldSize > 1), - OPS_LOG_E(nodeName, "Cannot support multi-moeExpert %d in a rank when tpWorldSize = %d > 1", - localMoeExpertNum, tpWorldSize), return false); - tilingData.moeCombineNormalInfo.moeExpertPerRankNum = localMoeExpertNum; - - // Validate topkWeights dimension 0 and set bs - const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX); - int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0); - OPS_CHECK((topkWeightsDim0 <= 0) || (topkWeightsDim0 > BS_UPPER_BOUND), - OPS_LOG_E(nodeName, "Invalid topkWeights dims0(BS) %ld. Should be between [1, %ld].", - topkWeightsDim0, BS_UPPER_BOUND), return false); - tilingData.moeCombineNormalInfo.bs = static_cast(topkWeightsDim0); - - // Validate globalBS - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return false); - auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); - OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBs is null."), return false); - OPS_LOG_D(nodeName, "MoeCombineNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n", - *globalBsPtr, topkWeightsDim0, epWorldSize); - - OPS_CHECK((*globalBsPtr != 0) && ((*globalBsPtr < static_cast(epWorldSize) * topkWeightsDim0) || - ((*globalBsPtr) % (static_cast(epWorldSize)) != 0)), OPS_LOG_E(nodeName, "globalBS is invalid, only " - "support 0 or maxBs(maxBs is the largest bs on all ranks) * epWorldSize, but got globalBS=%ld, " - "bs=%ld, epWorldSize=%u.", *globalBsPtr, topkWeightsDim0, epWorldSize), return false); - - tilingData.moeCombineNormalInfo.globalBs = static_cast(*globalBsPtr); - if (*globalBsPtr == 0) { - tilingData.moeCombineNormalInfo.globalBs = static_cast(topkWeightsDim0) * epWorldSize; - } - - return true; -} - -static ge::graphStatus TilingCheckMoeCombineNormal(gert::TilingContext *context, const char *nodeName) -{ - // Check parameter shape information - OPS_CHECK(!CheckTensorDim(context, nodeName), - OPS_LOG_E(nodeName, "param shape is invalid"), return ge::GRAPH_FAILED); - // Check parameter dataType information - OPS_CHECK(!CheckTensorDataType(context, nodeName), - OPS_LOG_E(nodeName, "param dataType is invalid"), return ge::GRAPH_FAILED); - // Check parameter format information - OPS_CHECK(!CheckTensorFormat(context, nodeName), - OPS_LOG_E(nodeName, "param Format is invalid"), return ge::GRAPH_FAILED); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus SetWorkspace(gert::TilingContext *context, const char *nodeName) -{ - size_t *workspace = context->GetWorkspaceSizes(1); - OPS_CHECK(workspace == nullptr, OPS_LOG_E(nodeName, "get workspace failed"), - return ge::GRAPH_FAILED); - workspace[0] = SYSTEM_NEED_WORKSPACE; - OPS_LOG_D(nodeName, "workspce[0] size is %ld", workspace[0]); - return ge::GRAPH_SUCCESS; -} - - -static void SetHCommCfg(gert::TilingContext *context, MoeCombineNormalTilingData *tiling, - const std::string groupEp, const std::string groupTp) -{ - const char* nodeName = context->GetNodeName(); - OPS_LOG_D(nodeName, "MoeCombineNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str()); - uint32_t opType1 = OP_TYPE_ALL_TO_ALL; - uint32_t opType2 = OP_TYPE_REDUCE_SCATTER; - std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; - std::string algConfigReduceScatterStr = "ReduceScatter=level0:ring"; - - AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr); - mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); - mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1); - - mc2CcTilingConfig.SetGroupName(groupTp); - mc2CcTilingConfig.SetOpType(opType2); - mc2CcTilingConfig.SetAlgConfig(algConfigReduceScatterStr); - mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2); -} - -static ge::graphStatus MoeCombineNormalA3TilingFuncImpl(gert::TilingContext* context) -{ - const char *nodeName = context->GetNodeName(); - OPS_LOG_D(nodeName, "Enter MoeCombineNormal Tiling func"); - MoeCombineNormalTilingData *tilingData = context->GetTilingData(); - OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); - std::string groupEp = ""; - std::string groupTp = ""; - uint32_t localMoeExpertNum = 1; - - // Get input parameter attributes - OPS_CHECK(GetAttrAndSetTilingData(context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED, - OPS_LOG_E(nodeName, "Getting attr failed."), return ge::GRAPH_FAILED); - - // Check input/output dim, format, dataType - OPS_CHECK(TilingCheckMoeCombineNormal(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling check params failed"), return ge::GRAPH_FAILED); - - // Check if attribute values are valid - OPS_CHECK(!CheckAttrs(context, *tilingData, nodeName, localMoeExpertNum), - OPS_LOG_E(nodeName, "attr check failed."), return ge::GRAPH_FAILED); - - uint32_t epRankId = tilingData->moeCombineNormalInfo.epRankId; - - // Check shape dimensions and assign h, k - OPS_CHECK(!CheckTensorShape(context, *tilingData, nodeName, localMoeExpertNum), - OPS_LOG_E(nodeName, "param dim check failed."), return ge::GRAPH_FAILED); - - // Validate win area size - uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); - uint64_t h = static_cast(tilingData->moeCombineNormalInfo.h); - uint64_t epWorldSize = static_cast(tilingData->moeCombineNormalInfo.epWorldSize); - uint64_t k = static_cast(tilingData->moeCombineNormalInfo.k); - uint64_t maxBs = static_cast(tilingData->moeCombineNormalInfo.globalBs)/ epWorldSize; - // Combine data area: token start address aligned to 512 - uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - // Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b) - uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_RECV_IDX_BUFFER; - uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET) * - DOUBLE_DATA_BUFFER; - OPS_CHECK((actualSize > maxWindowSize), - OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u," - " tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE(" - "((maxBs * tokenNeedSizeDispatch) + (maxBs * tokenNeedSizeCombine * k) + 3MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.", - maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k, - actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE), - return ge::GRAPH_FAILED); - tilingData->moeCombineNormalInfo.totalWinSize = maxWindowSize; - - OPS_CHECK(SetWorkspace(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(context->GetNodeName(), "Tiling set workspace Failed"), - return ge::GRAPH_FAILED); - - SetHCommCfg(context, tilingData, groupEp, groupTp); - - uint64_t tpWorldSize = static_cast(tilingData->moeCombineNormalInfo.tpWorldSize); - - uint32_t blockDim = 1U; - auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - uint64_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint64_t ubSize = 0UL; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); - blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); - context->SetBlockDim(blockDim); - tilingData->moeCombineNormalInfo.aivNum = aivNum; - tilingData->moeCombineNormalInfo.totalUbSize = ubSize; - context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously - OPS_LOG_D(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize); - PrintTilingDataInfo(nodeName, *tilingData); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus MoeCombineNormalTilingFunc(gert::TilingContext* context) -{ - // recvX data type int32 is not supported - auto recvXDesc = context->GetInputDesc(RECV_X_INDEX); - const char *nodeName = context->GetNodeName(); - OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return ge::GRAPH_FAILED); - // Check if recvX data type is DT_INT32 - OPS_CHECK((recvXDesc->GetDataType() == ge::DT_INT32), - OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is "), - return ge::GRAPH_FAILED); - - ge::graphStatus ret = MoeCombineNormalA3TilingFuncImpl(context); - return ret; -} - -struct MoeCombineNormalCompileInfo {}; -ge::graphStatus TilingParseForMoeCombineNormal(gert::TilingParseContext *context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_OPTILING(MoeCombineNormal) - .Tiling(MoeCombineNormalTilingFunc) - .TilingParse(TilingParseForMoeCombineNormal); -} // namespace optiling diff --git a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp b/csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp deleted file mode 100644 index 61a23c3b..00000000 --- a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "kernel_operator.h" -#include "lib/matmul_intf.h" -#include "moe_combine_normal.h" -#include "moe_combine_normal_tiling.h" -using namespace AscendC; -using namespace MoeCombineNormalImpl; - -extern "C" __global__ __aicore__ void moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut, - GM_ADDR workspaceGM, GM_ADDR tilingGM) - -{ - REGISTER_TILING_DEFAULT(MoeCombineNormalTilingData); - TPipe pipe; - -#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16) - GET_TILING_DATA_WITH_STRUCT(MoeCombineNormalTilingData, tilingData, tilingGM); - MoeCombineNormal op; - op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, workspaceGM, &pipe, &tilingData); - op.Process(); -#endif -} \ No newline at end of file diff --git a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h b/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h deleted file mode 100644 index 156e7248..00000000 --- a/csrc/moe_combine_normal/op_kernel/moe_combine_normal.h +++ /dev/null @@ -1,377 +0,0 @@ -#ifndef MOE_COMBINE_NORMAL_H -#define MOE_COMBINE_NORMAL_H - -#include "kernel_operator.h" -#include "kernel_tiling/kernel_tiling.h" -#include "../common/moe_distribute_base.h" -#include "moe_combine_normal_tiling.h" - -namespace MoeCombineNormalImpl { -constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U; -constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U; -constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U; -constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL; -constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL; -constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U; -constexpr uint32_t UB_32_ALIGN = 32U; -constexpr uint32_t MUL_256_ALIGN = 256U; -constexpr uint64_t WIN_512_ALIGN = 512UL; -constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U; -constexpr uint8_t DOUBLE_BUFFER = 2; - -template -__aicore__ inline void SyncFunc() -{ - int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); - AscendC::SetFlag(eventID); - AscendC::WaitFlag(eventID); -} - -#define TemplateMC2TypeClass typename RecvXType, typename XType, typename SrcInfoType -#define TemplateMC2TypeFunc RecvXType, XType, SrcInfoType - -using namespace AscendC; -template -class MoeCombineNormal { -public: - __aicore__ inline MoeCombineNormal() {}; - __aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR tpRecvCount,GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, - const MoeCombineNormalTilingData *tilingData); - __aicore__ inline void Process(); -private: - __aicore__ inline void InitMagic(); - __aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR XOut); - __aicore__ inline void InitTilingData(const MoeCombineNormalTilingData *tilingData); - __aicore__ inline void InitBuffLen(); - __aicore__ inline void CopyBufferToShareAndSetStatus(); - __aicore__ inline void CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex); - __aicore__ inline void ReadBufferFromRemote(); - __aicore__ inline void WaitBuffCopy(uint32_t tokenIndex); - __aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId); - __aicore__ inline void ReadBufferAndWeightedSum(uint32_t tokenIndex, uint32_t startTokenIndex); - - __aicore__ GM_ADDR GetStateAddrByRankId(const int32_t rankId) - { - GM_ADDR bufferAddr; - if (epRankId_ == rankId) { - bufferAddr = (GM_ADDR)epWinContext_->localWindowsIn; - } else { - bufferAddr = (GM_ADDR)((HcclRankRelationResV2 *)epWinContext_->remoteRes[rankId].nextDevicePtr)->windowsIn; - } - return (GM_ADDR)(bufferAddr + winDataSizeOffset_); - } - - __aicore__ GM_ADDR GetBufferAddrByRankId(const int32_t rankId) - { - return GetStateAddrByRankId(rankId) + COMBINE_STATE_WIN_OFFSET; - } - - __aicore__ inline void SplitCoreCal(uint32_t totalNum, uint32_t &perCoreNum, uint32_t &startIdx, uint32_t &endIdx) - { - perCoreNum = totalNum / aivNum_; - uint32_t remainderRankNum = totalNum % aivNum_; - - startIdx = perCoreNum * coreIdx_; - if (coreIdx_ < remainderRankNum) { - perCoreNum++; - startIdx += coreIdx_; - } else { - startIdx += remainderRankNum; - } - endIdx = startIdx + perCoreNum; - } - - __gm__ HcclOpResParam *epWinContext_{nullptr}; - __gm__ HcclOpResParam *tpWinContext_{nullptr}; - uint32_t axisBS_{0}; - uint32_t axisH_{0}; - uint32_t axisK_{0}; - uint32_t aivNum_{0}; - uint32_t epWorldSize_{0}; - uint32_t epRankId_{0}; - uint32_t coreIdx_{0}; - uint32_t moeExpertNum_{0}; - uint32_t moeExpertPerRankNum_{0}; - uint32_t magic_{0}; - uint64_t winDataSizeOffset_{0}; - uint32_t selfSendCnt_{0}; - uint32_t hRecvXTypeLen_{0}; - uint32_t h32AlignFloatLen_{0}; - uint32_t h256AlignFloatLen_{0}; - uint32_t h32AlignRecvXLen_{0}; - uint32_t h512AlignRecvXLen_{0}; - - TPipe *tpipe_{nullptr}; - TQue weightedSumQueue_; - TQueBind localCopyQueue_; - TBuf<> stateBuf_; - TBuf<> topkWeightsBuf_; - TBuf<> tokenFloatBuf_; - TBuf<> sumFloatBuf_; - TBuf<> weightedMulBuf_; - TBuf<> srcInfoBuf_; - TBuf<> xOutBuf_; - TBuf<> tempStateBuf_; - - GlobalTensor recvXGM_; - GlobalTensor tokenSrcInfoGM_; - GlobalTensor epRecvCountGM_; - GlobalTensor topkWeightsGM_; - GlobalTensor xOutGlobal_; - GM_ADDR localRankGM_; - GM_ADDR workspaceGM_; -}; - -template -__aicore__ inline void MoeCombineNormal::InitMagic() -{ - auto contextGM0 = AscendC::GetHcclContext(); - epWinContext_ = (__gm__ HcclOpResParam*)contextGM0; - - GlobalTensor selfMagicTensor; - selfMagicTensor.SetGlobalBuffer((__gm__ int32_t*)((GM_ADDR)epWinContext_->localWindowsExp + MAGIC_WIN_OFFSET + - coreIdx_ * WIN_512_ALIGN)); - DataCacheCleanAndInvalid(selfMagicTensor); - magic_ = selfMagicTensor(0); - selfMagicTensor(0) = ((magic_ == 0) ? 1 : 0); - DataCacheCleanAndInvalid(selfMagicTensor); -} - -template -__aicore__ inline void MoeCombineNormal::InitGlobalBuffer( - GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR XOut) -{ - recvXGM_.SetGlobalBuffer((__gm__ RecvXType*)recvX); - tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType*)tokenSrcInfo); - epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t*)epRecvCount); - topkWeightsGM_.SetGlobalBuffer((__gm__ float*)topkWeights); - xOutGlobal_.SetGlobalBuffer((__gm__ XType*)XOut); -} - -template -__aicore__ inline void MoeCombineNormal::InitTilingData(const MoeCombineNormalTilingData *tilingData) -{ - axisBS_ = tilingData->moeCombineNormalInfo.bs; - axisH_ = tilingData->moeCombineNormalInfo.h; - axisK_ = tilingData->moeCombineNormalInfo.k; - aivNum_ = tilingData->moeCombineNormalInfo.aivNum; - moeExpertNum_ = tilingData->moeCombineNormalInfo.moeExpertNum; - moeExpertPerRankNum_ = tilingData->moeCombineNormalInfo.moeExpertPerRankNum; - epWorldSize_ = tilingData->moeCombineNormalInfo.epWorldSize; - epRankId_ = tilingData->moeCombineNormalInfo.epRankId; -} - -template -__aicore__ inline void MoeCombineNormal::InitBuffLen() -{ - uint32_t hFloatSize = axisH_ * static_cast(sizeof(float)); - h32AlignFloatLen_ = Ceil(hFloatSize, UB_32_ALIGN) * UB_32_ALIGN; - h256AlignFloatLen_ = Ceil(hFloatSize, MUL_256_ALIGN) * MUL_256_ALIGN; - hRecvXTypeLen_ = axisH_ * sizeof(RecvXType); - h32AlignRecvXLen_ = Ceil(hRecvXTypeLen_, UB_32_ALIGN) * UB_32_ALIGN; - h512AlignRecvXLen_ = Ceil(hRecvXTypeLen_, WIN_512_ALIGN) * WIN_512_ALIGN; -} - -template -__aicore__ inline void MoeCombineNormal::Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, - GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR tpRecvCount, GM_ADDR XOut, - GM_ADDR workspaceGM, TPipe *pipe, - const MoeCombineNormalTilingData *tilingData) -{ - workspaceGM_ = workspaceGM; - tpipe_ = pipe; - coreIdx_ = GetBlockIdx(); - - InitMagic(); - InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, XOut); - InitTilingData(tilingData); - InitBuffLen(); - - PipeBarrier(); - winDataSizeOffset_ = static_cast(magic_) * (tilingData->moeCombineNormalInfo.totalWinSize / 2UL); - localRankGM_ = GetBufferAddrByRankId(epRankId_); - DataCacheCleanAndInvalid(epRecvCountGM_[moeExpertNum_ - 1]); - selfSendCnt_ = epRecvCountGM_(moeExpertNum_ - 1); -} - -template -__aicore__ inline void MoeCombineNormal::CopyBufferToShareAndSetStatus() -{ - PipeBarrier(); - uint32_t perBlockSendNum = 0, startTokenId = 0, endTokenId = 0; - SplitCoreCal(selfSendCnt_, perBlockSendNum, startTokenId, endTokenId); - if (perBlockSendNum == 0U) { - return; - } - - uint32_t blockLen = static_cast(perBlockSendNum * TOKEN_SRC_INFO_LEN * sizeof(uint32_t)); - tpipe_->Reset(); - tpipe_->InitBuffer(stateBuf_, UB_32_ALIGN); - tpipe_->InitBuffer(localCopyQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_); - tpipe_->InitBuffer(srcInfoBuf_, blockLen); - LocalTensor statusTensor = stateBuf_.AllocTensor(); - Duplicate(statusTensor, 0x3F800000, FLOAT_NUM_PER_ALIGN); - - LocalTensor srcInfoLocal = srcInfoBuf_.Get(); - const DataCopyExtParams dataCopyParams{1U, blockLen, 0U, 0U, 0U}; - const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; - DataCopyPad(srcInfoLocal, tokenSrcInfoGM_[startTokenId * TOKEN_SRC_INFO_LEN], dataCopyParams, padParams); - - SyncFunc(); - for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) { - uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN; - uint32_t srcRankId = static_cast(srcInfoLocal(index + RANK_ID_OFFSET_IN_SRC_INFO)); - uint32_t srcTokenId = static_cast(srcInfoLocal(index + TOKEN_IDX_OFFSET_IN_SRC_INFO)); - uint32_t srcTopkId = static_cast(srcInfoLocal(index + TOPK_IDX_OFFSET_IN_SRC_INFO)); - CopyBufferToShare(srcRankId, srcTokenId, srcTopkId, tokenIndex); - PipeBarrier(); - SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId); - } - SyncFunc(); -} - -template -__aicore__ inline void MoeCombineNormal::CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, - uint32_t srcTopkId, uint32_t tkIndex) -{ - uint32_t tokenOffset = tkIndex * axisH_; - GM_ADDR dstGM = GetBufferAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_; - GlobalTensor dstWindow; - dstWindow.SetGlobalBuffer((__gm__ XType*)dstGM); - DataCopyExtParams xOutCopyParams{1U, static_cast(hRecvXTypeLen_), 0U, 0U, 0U}; - DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - - LocalTensor localCopyTensor; - localCopyTensor = localCopyQueue_.AllocTensor(); - DataCopyPad(localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams); - localCopyQueue_.EnQue(localCopyTensor); - localCopyTensor = localCopyQueue_.DeQue(); - DataCopyPad(dstWindow, localCopyTensor, xOutCopyParams); - localCopyQueue_.FreeTensor(localCopyTensor); -} - -template -__aicore__ inline void MoeCombineNormal::SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, - uint32_t srcTopkId) -{ - LocalTensor statusTensor = stateBuf_.AllocTensor(); - GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN; - GlobalTensor stateGMTensor; - stateGMTensor.SetGlobalBuffer((__gm__ uint32_t*)stateGM); - DataCopy(stateGMTensor, statusTensor, FLOAT_NUM_PER_ALIGN); -} - -template -__aicore__ inline void MoeCombineNormal::WaitBuffCopy(uint32_t tokenIndex) -{ - uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN; - GM_ADDR stateGM = GetStateAddrByRankId(epRankId_) + tokenIndex * axisK_ * UB_32_ALIGN; // Calculate address offset - GlobalTensor stateGMTensor; - stateGMTensor.SetGlobalBuffer((__gm__ float*)stateGM); - float current = (float)0.0; - float target = (float)1.0 * axisK_ * FLOAT_NUM_PER_ALIGN; - SumParams sumPerKParams{1, calCount, calCount}; - LocalTensor stateTensorLocal = stateBuf_.Get(); - LocalTensor tempStateTensorLocal = tempStateBuf_.Get(); - while (current != target) { - SyncFunc(); - DataCopy(stateTensorLocal, stateGMTensor, calCount); - SyncFunc(); - Sum(tempStateTensorLocal, stateTensorLocal, sumPerKParams); - SyncFunc(); - current = tempStateTensorLocal(0); - } - SyncFunc(); - Duplicate(tempStateTensorLocal, (float)0.0, calCount); - SyncFunc(); - DataCopy(stateGMTensor, tempStateTensorLocal, calCount); -} - -template -__aicore__ inline void MoeCombineNormal::ReadBufferAndWeightedSum(uint32_t tokenIndex, - uint32_t startTokenIndex) -{ - LocalTensor tokenFloatLocal = tokenFloatBuf_.Get(); - LocalTensor weightedMulBufLocal = weightedMulBuf_.Get(); - LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); - LocalTensor topkWeightsLocal = topkWeightsBuf_.Get(); - LocalTensor stateTensorLocal = stateBuf_.Get(); - Duplicate(sumFloatBufLocal, static_cast(0), axisH_); - const DataCopyExtParams xOutCopyParams{1U, static_cast(hRecvXTypeLen_), 0U, 0U, 0U}; - - for (uint32_t topkId = 0U; topkId < axisK_; topkId++) { - float scale = topkWeightsLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId); - GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_; - GlobalTensor localTokenTensor; - localTokenTensor.SetGlobalBuffer((__gm__ XType*)localTokenAddr); - - LocalTensor tmpToken = weightedSumQueue_.AllocTensor(); - const DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(tmpToken, localTokenTensor, xOutCopyParams, copyPadExtParams); - weightedSumQueue_.EnQue(tmpToken); - tmpToken = weightedSumQueue_.DeQue(); - Cast(tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_); - PipeBarrier(); - AscendC::Muls(weightedMulBufLocal, tokenFloatLocal, scale, axisH_); - PipeBarrier(); - AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, weightedMulBufLocal, axisH_); - weightedSumQueue_.FreeTensor(tmpToken); - } - PipeBarrier(); - LocalTensor xOutLocal = xOutBuf_.Get(); - Cast(xOutLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, axisH_); - SyncFunc(); - DataCopyPad(xOutGlobal_[tokenIndex * axisH_], xOutLocal, xOutCopyParams); -} - -template -__aicore__ inline void MoeCombineNormal::ReadBufferFromRemote() -{ - if (axisBS_ == 0U) { - return; - } - uint32_t tokenPerBlock = 0U, startTokenIndex = 0U, endTokenIndex = 0U; - SplitCoreCal(axisBS_, tokenPerBlock, startTokenIndex, endTokenIndex); - - if (tokenPerBlock == 0U) { - return; - } - - tpipe_->Reset(); - tpipe_->InitBuffer(xOutBuf_, h32AlignRecvXLen_); - tpipe_->InitBuffer(tokenFloatBuf_, h32AlignFloatLen_); - tpipe_->InitBuffer(weightedMulBuf_, h256AlignFloatLen_); - tpipe_->InitBuffer(sumFloatBuf_, h32AlignFloatLen_); - tpipe_->InitBuffer(weightedSumQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_); - tpipe_->InitBuffer(stateBuf_, (axisK_) * UB_32_ALIGN); - tpipe_->InitBuffer(tempStateBuf_, (axisK_) * UB_32_ALIGN); - tpipe_->InitBuffer(topkWeightsBuf_, tokenPerBlock * axisK_ * sizeof(float)); - - LocalTensor topkWeightsLocal = topkWeightsBuf_.Get(); - const DataCopyExtParams bskParams{1U, static_cast(tokenPerBlock * axisK_ * sizeof(float)), 0U, 0U, 0U}; - const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; - DataCopyPad(topkWeightsLocal, topkWeightsGM_[startTokenIndex * axisK_], bskParams, copyPadFloatParams); - SyncFunc(); - - for (uint32_t tokenIndex = startTokenIndex; tokenIndex < endTokenIndex; tokenIndex++) { - WaitBuffCopy(tokenIndex); - SyncFunc(); // Sync with result datacopy on same tensor - ReadBufferAndWeightedSum(tokenIndex, startTokenIndex); - } -} - -template -__aicore__ inline void MoeCombineNormal::Process() -{ - if ASCEND_IS_AIV { // All AIV processing - CopyBufferToShareAndSetStatus(); - ReadBufferFromRemote(); - } -} - -} // MoeCombineNormalImpl -#endif // MOE_COMBINE_IMPL_H diff --git a/csrc/moe_combine_normal/op_kernel/moe_combine_normal_tiling.h b/csrc/moe_combine_normal/op_kernel/moe_combine_normal_tiling.h deleted file mode 100644 index b7c02bf0..00000000 --- a/csrc/moe_combine_normal/op_kernel/moe_combine_normal_tiling.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef MOE_COMBINE_NORMAL_TILING_H -#define MOE_COMBINE_NORMAL_TILING_H - -#include -#include "kernel_tiling/kernel_tiling.h" - -// a3 -struct MoeCombineNormalInfo { - uint32_t epWorldSize; - uint32_t tpWorldSize; - uint32_t epRankId; - uint32_t tpRankId; - uint32_t expertShardType; - uint32_t moeExpertNum; - uint32_t moeExpertPerRankNum; - uint32_t globalBs; - uint32_t bs; - uint32_t k; - uint32_t h; - uint32_t aivNum; - uint64_t totalUbSize; - uint64_t totalWinSize; - float armAvgFactor; - float epsilon; -}; -struct MoeCombineNormalTilingData { - Mc2InitTiling mc2InitTiling; - Mc2CcTiling mc2CcTiling1; - Mc2CcTiling mc2CcTiling2; - MoeCombineNormalInfo moeCombineNormalInfo; -}; - -#endif //MOE_COMBINE_NORMAL_TILING_H \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_host/CMakeLists.txt b/csrc/moe_dispatch_normal/op_host/CMakeLists.txt deleted file mode 100644 index c6afc9f5..00000000 --- a/csrc/moe_dispatch_normal/op_host/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_ops_compile_options( - OP_NAME MoeDispatchNormal - OPTIONS --cce-auto-sync=off - -Wno-deprecated-declarations - -Werror -) - -target_sources(op_host_aclnnInner PRIVATE - moe_dispatch_normal.cpp -) - -target_sources(opapi PRIVATE - aclnn_moe_dispatch_normal.cpp -) - -if (NOT BUILD_OPEN_PROJECT) - target_sources(aclnn_ops_train PRIVATE - aclnn_moe_dispatch_normal.cpp - ) - - target_sources(aclnn_ops_infer PRIVATE - aclnn_moe_dispatch_normal.cpp - ) -endif () - -target_sources(optiling PRIVATE - moe_dispatch_normal_tiling.cpp -) - -target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_sources(opsproto PRIVATE) - -file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_dispatch_normal.h") - -install(FILES ${_GMM_Aclnn_header} - DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL -) diff --git a/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.cpp b/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.cpp deleted file mode 100644 index 85943a38..00000000 --- a/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include -#include "graph/types.h" -#include "aclnn_moe_dispatch_normal.h" - -enum NnopbaseHcclServerType { - NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, - NNOPBASE_HCCL_SERVER_TYPE_MTE, - NNOPBASE_HCCL_SERVER_TYPE_END -}; -extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); - -#ifdef __cplusplus -extern "C" { -#endif - -extern aclnnStatus aclnnInnerMoeDispatchNormalGetWorkspaceSize( - const aclTensor *x, - const aclTensor *topkIdx, - const aclTensor *sendOffset, - const aclTensor *sendTokenIdx, - const aclTensor *recvOffset, - const aclTensor *recvCount, - char *groupEp, - int64_t epWorldSize, - int64_t epRankId, - char *groupTpOptional, - int64_t tpWorldSize, - int64_t tpRankId, - int64_t moeExpertNum, - int64_t quantMode, - int64_t globalBs, - const aclTensor *recvX, - const aclTensor *recvXScales, - const aclTensor *assistInfoForCombine, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -extern aclnnStatus aclnnInnerMoeDispatchNormal( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x, const aclTensor *topkIdx, - const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset, const aclTensor *recvCount, - char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, int64_t tpWorldSize, int64_t tpRankId, - int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, const aclTensor *recvX, - const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, uint64_t *workspaceSize, - aclOpExecutor **executor) -{ - return aclnnInnerMoeDispatchNormalGetWorkspaceSize(x, - topkIdx, - sendOffset, - sendTokenIdx, - recvOffset, - recvCount, - groupEp, - epWorldSize, - epRankId, - groupTpOptional, - tpWorldSize, - tpRankId, - moeExpertNum, - quantMode, - globalBs, - recvX, - recvXScales, - assistInfoForCombine, - workspaceSize, - executor); -} - -aclnnStatus aclnnMoeDispatchNormal( - void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) -{ - if (NnopbaseSetHcclServerType) { - NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); - } - return aclnnInnerMoeDispatchNormal(workspace, workspaceSize, executor, stream); -} - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h b/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h deleted file mode 100644 index 0171db1a..00000000 --- a/csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef ACLNN_MOE_DISPATCH_NORMAL_H_ -#define ACLNN_MOE_DISPATCH_NORMAL_H_ - -#include "aclnn/acl_meta.h" - -#ifdef __cplusplus -extern "C" { -#endif - -__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x, - const aclTensor *topkIdx, const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset, - const aclTensor *recvCount, char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, - int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, - const aclTensor *recvX, const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, - uint64_t *workspaceSize, aclOpExecutor **executor); - -__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormal( - void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream); - -#ifdef __cplusplus -} -#endif - -#endif \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp b/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp deleted file mode 100644 index 5fb05995..00000000 --- a/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include "register/op_def_registry.h" - -namespace ops { -class MoeDispatchNormal : public OpDef { -public: - explicit MoeDispatchNormal(const char *name) : OpDef(name) - { - this->Input("x") - .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("topk_idx") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - - this->Input("send_offset") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("send_tokenIdx") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("recv_offset") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("recv_count") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - - this->Output("recv_x") - .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_INT8}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Output("x_scales") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Output("assist_info_for_combine") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Attr("group_ep").AttrType(REQUIRED).String(); - this->Attr("ep_world_size").AttrType(REQUIRED).Int(); - this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); - this->Attr("group_tp").AttrType(OPTIONAL).String(""); - this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0); - this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0); - this->Attr("moe_expert_num").AttrType(REQUIRED).Int(); - this->Attr("quant_mode").AttrType(OPTIONAL).Int(0); - this->Attr("global_bs").AttrType(OPTIONAL).Int(0); - - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("jitCompile.flag", "static_true") - .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); - - this->AICore().AddConfig("ascend910_93", aicore_config); - this->MC2().HcclGroup({"group_ep", "group_tp"}); - } -}; - -OP_ADD(MoeDispatchNormal); - -} // namespace ops \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp b/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp deleted file mode 100644 index 0d01272c..00000000 --- a/csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp +++ /dev/null @@ -1,635 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "register/tilingdata_base.h" -#include "tiling/tiling_api.h" -#include "log/ops_log.h" -#include "graph/utils/type_utils.h" -#include "register/op_def_registry.h" -#include "../op_kernel/moe_dispatch_normal_tiling.h" - -using namespace AscendC; -using namespace ge; -namespace { -class Mc2TilingUtils { -public: -#define HCCL_BUFFSIZE "HCCL_BUFFSIZE" - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - if (getenv(HCCL_BUFFSIZE) == nullptr) { - OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(HCCL_BUFFSIZE)); - defaultWindowSize = std::stoi(envStr); - } catch (const std::invalid_argument &ia) { - OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); - } catch (const std::out_of_range &oor) { - OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } -}; -constexpr uint32_t X_INDEX = 0U; -constexpr uint32_t EXPERT_IDS_INDEX = 1U; -constexpr uint32_t SEND_OFFSET_INDEX = 2U; -constexpr uint32_t SEND_TOKENIDX_INDEX = 3U; -constexpr uint32_t RECV_OFFSET_INDEX = 4U; -constexpr uint32_t RECV_COUNT_INDEX = 5U; - -constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0U; -constexpr uint32_t OUTPUT_DYNAMIC_SCALES_INDEX = 1U; -constexpr uint32_t OUTPUT_ASSIST_INFO_INDEX = 2U; - -constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; -constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; -constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; -constexpr uint32_t ATTR_GROUP_TP_INDEX = 3; -constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4; -constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5; -constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6; -constexpr uint32_t ATTR_QUANT_MODE_INDEX = 7; -constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 8; - -constexpr uint32_t TWO_DIMS = 2; -constexpr uint32_t ONE_DIM = 1; -constexpr uint32_t DYNAMIC_SCALE_DIM_NUM = 1; -constexpr uint64_t INIT_TILINGKEY = 10000; -constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8; -constexpr uint32_t NO_SCALES = 0; -constexpr uint32_t DYNAMIC_SCALES = 2; -constexpr uint32_t OP_TYPE_ALL_GATHER = 6; - -constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL; -constexpr int64_t MAX_EP_WORLD_SIZE = 384; -constexpr int64_t MIN_EP_WORLD_SIZE = 2; -constexpr int64_t MAX_TP_WORLD_SIZE = 2; -constexpr int64_t BS_UPPER_BOUND = 8000; // Maximum bs - -constexpr uint32_t TILINGKEY_TP_WORLD_SIZE = 100; -constexpr uint32_t TP_WORLD_SIZE_TWO = 2; -constexpr int64_t MOE_EXPERT_MAX_NUM = 512; -constexpr int64_t K_MAX = 16; -constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; -constexpr uint32_t WORKSPACE_ELEMENT_OFFSET = 512; -constexpr int64_t H_MIN = 1024; -constexpr int64_t H_MAX = 7168; -constexpr uint64_t MB_SIZE = 1024UL * 1024UL; -constexpr uint64_t TRIPLE = 3; -constexpr uint64_t WIN_ADDR_ALIGN = 512UL; -constexpr uint64_t SCALE_EXPAND_IDX_BUFFER = 44UL; // scale32B + 3*4expandIdx -constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL; -constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL; -constexpr uint64_t UB_ALIGN = 32UL; -constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL; -} // namespace - -namespace optiling { -static void PrintTilingDataInfo(const char *nodeName, MoeDispatchNormalTilingData &tilingData) -{ - OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeDispatchNormalInfo.epWorldSize); - OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeDispatchNormalInfo.tpWorldSize); - OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeDispatchNormalInfo.epRankId); - OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeDispatchNormalInfo.tpRankId); - OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeDispatchNormalInfo.moeExpertNum); - OPS_LOG_D(nodeName, "quantMode is %u.", tilingData.moeDispatchNormalInfo.quantMode); - OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeDispatchNormalInfo.globalBs); - OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeDispatchNormalInfo.bs); - OPS_LOG_D(nodeName, "k is %u.", tilingData.moeDispatchNormalInfo.k); - OPS_LOG_D(nodeName, "h is %u.", tilingData.moeDispatchNormalInfo.h); - OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeDispatchNormalInfo.aivNum); - OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeDispatchNormalInfo.totalUbSize); - OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeDispatchNormalInfo.totalWinSize); -} - -static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) -{ - const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX); - OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "xShape is null."), return false); - OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, - "xShape dims must be 2, but current dim num is %lu.", - xStorageShape->GetStorageShape().GetDimNum()), - return false); - int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); - int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); - OPS_LOG_D(nodeName, "x dim0 = %ld", xDim0); - OPS_LOG_D(nodeName, "x dim1 = %ld", xDim1); - - const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX); - OPS_CHECK(expertIdStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIdShape is null."), return false); - OPS_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, - "expertIdShape dims must be 2, but current dim num is %lu.", - expertIdStorageShape->GetStorageShape().GetDimNum()), - return false); - OPS_LOG_D(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0)); - OPS_LOG_D(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1)); - - const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX); - OPS_CHECK(expandXStorageShape == nullptr, OPS_LOG_E(nodeName, "expandXShape is null."), return false); - OPS_CHECK(expandXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, - "expandXShape dims must be 2, but current dim num is %lu.", - expandXStorageShape->GetStorageShape().GetDimNum()), - return false); - OPS_LOG_D(nodeName, "expandX dim0 = %ld", expandXStorageShape->GetStorageShape().GetDim(0)); - OPS_LOG_D(nodeName, "expandX dim1 = %ld", expandXStorageShape->GetStorageShape().GetDim(1)); - - if (quantMode == DYNAMIC_SCALES) { - const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX); - OPS_CHECK( - dynamicScalesStorageShape == nullptr, OPS_LOG_E(nodeName, "dynamicScalesShape is null."), return false); - OPS_CHECK(dynamicScalesStorageShape->GetStorageShape().GetDimNum() != DYNAMIC_SCALE_DIM_NUM, - OPS_LOG_E(nodeName, - "dynamicScalesShape dims must be %u, but current dim num is %lu.", - DYNAMIC_SCALE_DIM_NUM, - dynamicScalesStorageShape->GetStorageShape().GetDimNum()), - return false); - OPS_LOG_D(nodeName, "dynamicScales dim0 = %ld", dynamicScalesStorageShape->GetStorageShape().GetDim(0)); - } - - const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX); - OPS_CHECK(assistInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "assistInfoShape is null."), return false); - OPS_CHECK(assistInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, - OPS_LOG_E(nodeName, - "assistInfoShape dims must be 1, but current dim num is %lu.", - assistInfoStorageShape->GetStorageShape().GetDimNum()), - return false); - OPS_LOG_D(nodeName, "assistInfoForCombine dim0 = %ld", assistInfoStorageShape->GetStorageShape().GetDim(0)); - - return true; -} - -static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) -{ - auto xDesc = context->GetInputDesc(X_INDEX); - OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false); - OPS_CHECK((xDesc->GetDataType() != ge::DT_BF16) && (xDesc->GetDataType() != ge::DT_FLOAT16), - OPS_LOG_E(nodeName, "x dataType is invalid, dataType should be bf16 or float16, but is ."), - return false); - - auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX); - OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false); - OPS_CHECK(expertIdDesc->GetDataType() != ge::DT_INT32, - OPS_LOG_E(nodeName, "expertId dataType is invalid, dataType should be int32, but is ."), - return false); - - auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX); - OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false); - if (quantMode != NO_SCALES) { - OPS_CHECK(expandXDesc->GetDataType() != ge::DT_INT8, - OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be int8, but is."), - return false); - } else { - OPS_CHECK(expandXDesc->GetDataType() != xDesc->GetDataType(), - OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be equal to x dataType , but is."), - return false); - } - - if (quantMode == DYNAMIC_SCALES) { - auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX); - OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false); - OPS_CHECK(dynamicScalesDesc->GetDataType() != ge::DT_FLOAT, - OPS_LOG_E(nodeName, "dynamicScales dataType is invalid, dataType should be float, but is ."), - return false); - } - - auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX); - OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false); - OPS_CHECK(assistInfoDesc->GetDataType() != ge::DT_INT32, - OPS_LOG_E(nodeName, "assistInfoForCombine dataType is invalid, dataType should be int32, but is ."), - return false); - - return true; -} - -static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) -{ - auto xDesc = context->GetInputDesc(X_INDEX); - OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "x format is invalid."), - return false); - - auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX); - OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false); - OPS_CHECK( - static_cast(ge::GetPrimaryFormat(expertIdDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "expertId format is invalid."), - return false); - - auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX); - OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false); - OPS_CHECK( - static_cast(ge::GetPrimaryFormat(expandXDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "expandX format is invalid."), - return false); - - if (quantMode == DYNAMIC_SCALES) { - auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX); - OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false); - OPS_CHECK(static_cast(ge::GetPrimaryFormat(dynamicScalesDesc->GetStorageFormat())) == - ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "dynamicScales format is invalid."), - return false); - } - - auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX); - OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false); - OPS_CHECK( - static_cast(ge::GetPrimaryFormat(assistInfoDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OPS_LOG_E(nodeName, "assistInfoForCombine format is invalid."), - return false); - - return true; -} - -static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, - MoeDispatchNormalTilingData &tilingData, std::string &groupEp, std::string &groupTp) -{ - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); - - auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); - auto groupTpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_TP_INDEX)); - auto epWorldSizePtr = attrs->GetAttrPointer(ATTR_EP_WORLD_SIZE_INDEX); - auto tpWorldSizePtr = attrs->GetAttrPointer(ATTR_TP_WORLD_SIZE_INDEX); - auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); - auto tpRankIdPtr = attrs->GetAttrPointer(ATTR_TP_RANK_ID_INDEX); - auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); - auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); - - // Check for null - OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) || - (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), - OPS_LOG_E(nodeName, "groupEpPtr is null."), - return ge::GRAPH_FAILED); - OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSizePtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSizePtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankIdPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(quantModePtr == nullptr, OPS_LOG_E(nodeName, "quantModePtr is null."), return ge::GRAPH_FAILED); - - // Check if it meets uint32_t and other constraints - int64_t moeExpertNum = *moeExpertNumPtr; - int64_t epWorldSize = *epWorldSizePtr; - OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE), - OPS_LOG_E(nodeName, - "epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.", - MIN_EP_WORLD_SIZE, - MAX_EP_WORLD_SIZE, - epWorldSize), - return ge::GRAPH_FAILED); - OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE), - OPS_LOG_E(nodeName, - "tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.", - MAX_TP_WORLD_SIZE, - *tpWorldSizePtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize), - OPS_LOG_E( - nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.", epWorldSize, *epRankIdPtr), - return ge::GRAPH_FAILED); - if (*tpWorldSizePtr > 1) { - OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr), - OPS_LOG_E(nodeName, - "tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.", - *tpWorldSizePtr, - *tpRankIdPtr), - return ge::GRAPH_FAILED); - OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) || - (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), - OPS_LOG_E(nodeName, "groupTpPtr is null."), - return ge::GRAPH_FAILED); - groupTp = std::string(groupTpPtr); - } else { - OPS_CHECK(*tpRankIdPtr != 0, - OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr), - return ge::GRAPH_FAILED); - } - OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM), - OPS_LOG_E(nodeName, - "moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.", - MOE_EXPERT_MAX_NUM, - moeExpertNum), - return ge::GRAPH_FAILED); - OPS_CHECK( - (*quantModePtr < static_cast(NO_SCALES)) || (*quantModePtr > static_cast(DYNAMIC_SCALES)), - OPS_LOG_E(nodeName, - "quantMode is invalid, only support [0, %u], but got quantMode=%ld.", - DYNAMIC_SCALES, - *quantModePtr), - return ge::GRAPH_FAILED); - - int64_t moePerRankNum = moeExpertNum / epWorldSize; - int64_t curDispatchStatusNum = moePerRankNum * epWorldSize; - OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM), - OPS_LOG_E(nodeName, - "The moe experts num must meet the conditions," - " (moeExpertNum / epWorldSize * epWorldSize <= 1280, but cur is %ld.", - curDispatchStatusNum), - return ge::GRAPH_FAILED); - - groupEp = std::string(groupEpPtr); - tilingData.moeDispatchNormalInfo.epWorldSize = static_cast(epWorldSize); - tilingData.moeDispatchNormalInfo.tpWorldSize = static_cast(*tpWorldSizePtr); - tilingData.moeDispatchNormalInfo.epRankId = static_cast(*epRankIdPtr); - tilingData.moeDispatchNormalInfo.tpRankId = static_cast(*tpRankIdPtr); - tilingData.moeDispatchNormalInfo.moeExpertNum = static_cast(moeExpertNum); - tilingData.moeDispatchNormalInfo.quantMode = static_cast(*quantModePtr); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus CheckAttrs( - gert::TilingContext *context, const char *nodeName, MoeDispatchNormalTilingData &tilingData, uint32_t &localMoeExpertNum) -{ - uint32_t epWorldSize = tilingData.moeDispatchNormalInfo.epWorldSize; - uint32_t tpWorldSize = tilingData.moeDispatchNormalInfo.tpWorldSize; - uint32_t moeExpertNum = tilingData.moeDispatchNormalInfo.moeExpertNum; - - // Validate if moe expert number can be evenly distributed across multiple machines - localMoeExpertNum = moeExpertNum / epWorldSize; - OPS_CHECK(moeExpertNum % epWorldSize != 0, - OPS_LOG_E(nodeName, - "moeExpertNum should be divisible by epWorldSize, " - "but moeExpertNum=%u, epWorldSize=%u.", - moeExpertNum, - epWorldSize), - return ge::GRAPH_FAILED); - OPS_CHECK(localMoeExpertNum <= 0, - OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum), - return ge::GRAPH_FAILED); - - // Validate input x dimension 0 and set bs - const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX); - const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); - OPS_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0), - OPS_LOG_E( - nodeName, "xDim0(BS) is invalid. Should be between [1, %ld], but got xDim0=%ld.", BS_UPPER_BOUND, xDim0), - return ge::GRAPH_FAILED); - tilingData.moeDispatchNormalInfo.bs = static_cast(xDim0); - - // Validate globalBS - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); - auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); - OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBsPtr is nullptr."), return ge::GRAPH_FAILED); - OPS_LOG_D(nodeName, "MoeDispatchNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n", *globalBsPtr, xDim0, epWorldSize); - OPS_CHECK(*globalBsPtr <= 0, - OPS_LOG_E(nodeName, - "globalBS is invalid, should be positive, but got globalBS=%ld.", - *globalBsPtr), - return ge::GRAPH_FAILED); - - tilingData.moeDispatchNormalInfo.globalBs = static_cast(*globalBsPtr); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, - MoeDispatchNormalTilingData &tilingData, const uint32_t quantMode, const int64_t localMoeExpertNum) -{ - uint32_t A = 0U; - uint32_t globalBs = tilingData.moeDispatchNormalInfo.globalBs; - - // Validate input x dimension 1 and set h, bs already validated - const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX); - const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); - const int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); - OPS_CHECK((xDim1 < H_MIN) || (xDim1 > H_MAX), - OPS_LOG_E(nodeName, "xShape dims1(H) should be in [%ld, %ld], but got %ld.", H_MIN, H_MAX, xDim1), - return ge::GRAPH_FAILED); // 32-byte aligned - tilingData.moeDispatchNormalInfo.h = static_cast(xDim1); - - // Validate expert_id dimensions and set k - int64_t moeExpertNum = static_cast(tilingData.moeDispatchNormalInfo.moeExpertNum); - const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX); - const int64_t expertIdsDim0 = expertIdStorageShape->GetStorageShape().GetDim(0); - const int64_t expertIdsDim1 = expertIdStorageShape->GetStorageShape().GetDim(1); - OPS_CHECK(xDim0 != expertIdsDim0, - OPS_LOG_E(nodeName, - "xShape's dim0 not equal to expertIdShape's dim0, " - "xShape's dim0 is %ld, expertIdShape's dim0 is %ld.", - xDim0, - expertIdsDim0), - return ge::GRAPH_FAILED); - OPS_CHECK((expertIdsDim1 <= 0) || (expertIdsDim1 > K_MAX) || (expertIdsDim1 > moeExpertNum), - OPS_LOG_E(nodeName, - "expertIdShape's dim1(k) should be in (0, min(%ld, moeExpertNum=%ld)], " - "but got expertIdShape's dim1=%ld.", - K_MAX, - moeExpertNum, - expertIdsDim1), - return ge::GRAPH_FAILED); - tilingData.moeDispatchNormalInfo.k = static_cast(expertIdsDim1); - - A = globalBs; - - // Validate expandX dimensions - const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX); - const int64_t expandXDim0 = expandXStorageShape->GetStorageShape().GetDim(0); - const int64_t expandXDim1 = expandXStorageShape->GetStorageShape().GetDim(1); - - OPS_CHECK(xDim1 != expandXDim1, - OPS_LOG_E(nodeName, - "expandX's dim1 not equal to xShape's dim1, " - "xShape's dim1 is %ld, expandX's dim1 is %ld.", - xDim1, - expandXDim1), - return ge::GRAPH_FAILED); - - // Validate dynamicScales dimensions - if (quantMode != NO_SCALES) { - const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX); - const int64_t dynamicScalesDim0 = dynamicScalesStorageShape->GetStorageShape().GetDim(0); - } - - // Validate assistInfo dimensions - const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX); - const int64_t assistInfoDim0 = assistInfoStorageShape->GetStorageShape().GetDim(0); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus TilingCheckMoeDispatchNormal( - gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) -{ - OPS_CHECK(!CheckTensorDim(context, nodeName, quantMode), - OPS_LOG_E(nodeName, "params shape is invalid."), - return ge::GRAPH_FAILED); - OPS_CHECK(!CheckTensorDataType(context, nodeName, quantMode), - OPS_LOG_E(nodeName, "params dataType is invalid."), - return ge::GRAPH_FAILED); - OPS_CHECK(!CheckTensorFormat(context, nodeName, quantMode), - OPS_LOG_E(nodeName, "params format is invalid."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -static void CalTilingKey(uint64_t &tilingKey, const uint32_t quantMode, const uint32_t tpWorldSize) -{ - tilingKey += static_cast(quantMode); - if (tpWorldSize == TP_WORLD_SIZE_TWO) { - tilingKey += static_cast(TILINGKEY_TP_WORLD_SIZE); - } - - return; -} - -static void SetHcommCfg(const gert::TilingContext *context, MoeDispatchNormalTilingData *tiling, const std::string groupEp, - const std::string groupTp) -{ - const char *nodeName = context->GetNodeName(); - OPS_LOG_D(nodeName, "MoeDispatchNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str()); - uint32_t opType1 = OP_TYPE_ALL_TO_ALL; - uint32_t opType2 = OP_TYPE_ALL_GATHER; - std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; - std::string algConfigAllGatherStr = "AllGather=level0:ring"; - - AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr); - mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); - mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1); - - mc2CcTilingConfig.SetGroupName(groupTp); - mc2CcTilingConfig.SetOpType(opType2); - mc2CcTilingConfig.SetAlgConfig(algConfigAllGatherStr); - mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2); -} - -static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) -{ - size_t *workSpaces = context->GetWorkspaceSizes(1); - OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); - auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - workSpaces[0] = static_cast(SYSTEM_NEED_WORKSPACE + WORKSPACE_ELEMENT_OFFSET * aivNum * aivNum); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus MoeDispatchNormalA3TilingFuncImpl(gert::TilingContext *context) -{ - const char *nodeName = context->GetNodeName(); - MoeDispatchNormalTilingData *tilingData = context->GetTilingData(); - OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); - std::string groupEp = ""; - std::string groupTp = ""; - uint32_t quantMode = NO_SCALES; - uint32_t localMoeExpertNum = 1; - OPS_LOG_I(nodeName, "Enter MoeDispatchNormal tiling check func."); - - // Get input parameter attributes - OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp, groupTp) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), - return ge::GRAPH_FAILED); - - quantMode = tilingData->moeDispatchNormalInfo.quantMode; - - // Check input/output dim, format, dataType - OPS_CHECK(TilingCheckMoeDispatchNormal(context, nodeName, quantMode) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling check param failed."), - return ge::GRAPH_FAILED); - - // Check if attribute values are valid - OPS_CHECK(CheckAttrs(context, nodeName, *tilingData, localMoeExpertNum) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Check attr failed."), - return ge::GRAPH_FAILED); - - uint32_t epRankId = tilingData->moeDispatchNormalInfo.epRankId; - - // Check shape dimensions and assign h, k - OPS_CHECK( - CheckTensorShape(context, nodeName, *tilingData, quantMode, static_cast(localMoeExpertNum)) != - ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Check tensor shape failed."), - return ge::GRAPH_FAILED); - - // Validate win area size - uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); - uint64_t h = static_cast(tilingData->moeDispatchNormalInfo.h); - uint64_t k = static_cast(tilingData->moeDispatchNormalInfo.k); - uint64_t epWorldSize = static_cast(tilingData->moeDispatchNormalInfo.epWorldSize); - uint64_t maxBs = static_cast(tilingData->moeDispatchNormalInfo.globalBs) / epWorldSize; - - // Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b) - uint64_t tokenActualLen = - ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_EXPAND_IDX_BUFFER; - uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - // Not considering dual stream size - uint64_t actualSize = maxBs * k * tokenNeedSizeDispatch * DOUBLE_DATA_BUFFER; - OPS_CHECK((actualSize > maxWindowSize), - OPS_LOG_E(nodeName, - "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu," - " localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu," - " k = %lu, NEEDED_HCCL_BUFFSIZE(maxBs * k * tokenNeedSizeDispatch) = %luMB," - " HCCL_BUFFSIZE=%luMB.", - maxBs, - h, - epWorldSize, - localMoeExpertNum, - tokenNeedSizeDispatch, - k, - actualSize / MB_SIZE + 1UL, - maxWindowSize / MB_SIZE), - return ge::GRAPH_FAILED); - tilingData->moeDispatchNormalInfo.totalWinSize = maxWindowSize; - OPS_LOG_D(nodeName, "windowSize = %lu", maxWindowSize); - - OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling set workspace failed."), - return ge::GRAPH_FAILED); - SetHcommCfg(context, tilingData, groupEp, groupTp); - uint32_t tpWorldSize = tilingData->moeDispatchNormalInfo.tpWorldSize; - uint64_t tilingKey = INIT_TILINGKEY; - CalTilingKey(tilingKey, quantMode, tpWorldSize); - OPS_LOG_D(nodeName, "tilingKey is %lu", tilingKey); - context->SetTilingKey(tilingKey); - uint32_t blockDim = 1U; - auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint64_t ubSize = 0UL; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); - blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); - context->SetBlockDim(blockDim); - context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously - tilingData->moeDispatchNormalInfo.totalUbSize = ubSize; - tilingData->moeDispatchNormalInfo.aivNum = aivNum; - OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize); - PrintTilingDataInfo(nodeName, *tilingData); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus MoeDispatchNormalTilingFunc(gert::TilingContext *context) -{ - ge::graphStatus ret = MoeDispatchNormalA3TilingFuncImpl(context); - return ret; -} - -struct MoeDispatchNormalCompileInfo {}; -ge::graphStatus TilingParseForMoeDispatchNormal(gert::TilingParseContext *context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_OPTILING(MoeDispatchNormal) - .Tiling(MoeDispatchNormalTilingFunc) - .TilingParse(TilingParseForMoeDispatchNormal); -} // namespace optiling \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp b/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp deleted file mode 100644 index 0333f2d5..00000000 --- a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "kernel_operator.h" -#include "moe_dispatch_normal_tiling.h" -#include "moe_dispatch_normal.h" - -using namespace AscendC; -using namespace MoeDispatchNormalImpl; - -#define TILINGKEY_NO_QUANT 10000 -#define TILINGKEY_QUANT 10002 - -extern "C" __global__ __aicore__ void moe_dispatch_normal(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, - GM_ADDR send_token_idx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, - GM_ADDR assist_info_for_combine, GM_ADDR workspaceGM, GM_ADDR tilingGM) -{ - REGISTER_TILING_DEFAULT(MoeDispatchNormalTilingData); - TPipe pipe; -#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16) - if (TILING_KEY_IS(TILINGKEY_NO_QUANT)) { - GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM); - MoeDispatchNormal op; - op.Init(x, - expertIds, - send_offset, - send_token_idx, - recv_offset, - recv_count, - expandXOut, - dynamicScalesOut, - assist_info_for_combine, - workspaceGM, - &pipe, - &tilingData); - op.Process(); - return; - } -#elif (ORIG_DTYPE_RECV_X == DT_INT8) - if (TILING_KEY_IS(TILINGKEY_QUANT)) { - GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM); - MoeDispatchNormal op; - op.Init(x, - expertIds, - send_offset, - send_token_idx, - recv_offset, - recv_count, - expandXOut, - dynamicScalesOut, - assist_info_for_combine, - workspaceGM, - &pipe, - &tilingData); - op.Process(); - return; - } -#endif -} \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h b/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h deleted file mode 100644 index 2af4e580..00000000 --- a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h +++ /dev/null @@ -1,540 +0,0 @@ -#ifndef MOE_DISPATCH_NORMAL_H -#define MOE_DISPATCH_NORMAL_H - -#include "kernel_operator.h" -#include "kernel_tiling/kernel_tiling.h" -#include "../common/moe_distribute_base.h" -#include "moe_dispatch_normal_tiling.h" - -namespace MoeDispatchNormalImpl { -constexpr uint8_t BUFFER_NUM = 2; -constexpr uint32_t STATE_OFFSET = 32U; -constexpr uint32_t UB_ALIGN = 32U; -constexpr uint8_t COMM_NUM = 2; -constexpr uint8_t COMM_EP_IDX = 0; -constexpr uint8_t COMM_TP_IDX = 1; - -constexpr uint64_t WIN_STATE_OFFSET = 500UL * 1024UL; -constexpr uint64_t STATE_WIN_OFFSET = 950UL * 1024UL; -constexpr uint64_t WIN_ADDR_ALIGN = 512UL; -constexpr uint32_t EXPAND_IDX_INFO = 3U; -constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL; - -template -__aicore__ inline void SyncFunc() -{ - int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); - AscendC::SetFlag(eventID); - AscendC::WaitFlag(eventID); -} - -#define CamTypeClass \ - typename XType, typename ExpandXOutType, bool DynamicQuant, bool IsSmoothScaleExist, bool IsShareExpertRank - -#define CamTypeFunc XType, ExpandXOutType, DynamicQuant, IsSmoothScaleExist, IsShareExpertRank - -using namespace AscendC; -template -class MoeDispatchNormal { -public: - __aicore__ inline MoeDispatchNormal(){}; - __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx, - GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, - GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData); - __aicore__ inline void Process(); - -private: - __aicore__ inline void InputToShare(); - __aicore__ inline void SetStatus(); - __aicore__ inline void WaitStatus(); - __aicore__ inline void ShareToOutput(); - __aicore__ inline void UpdateOutput(); - __aicore__ inline void FillTriple(LocalTensor &xOutTensor, uint32_t tokenIndex, uint32_t k); - __aicore__ inline void QuantInit(); - __aicore__ inline void ReduceMaxInplace(const LocalTensor &srcLocal, uint32_t count); - __aicore__ inline void QuantProcess(); - __aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId) - { - uint32_t curRankId = ((ctxIdx == COMM_EP_IDX) ? epRankId : tpRankId); - if (curRankId == rankId) { - return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset + COMBINE_STATE_WIN_OFFSET; - } - return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) + - winDataSizeOffset + COMBINE_STATE_WIN_OFFSET; - } - - __aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) - { - uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId : tpRankId; - if (curRankId == rankId) { - return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState * WIN_STATE_OFFSET; - } - return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr)) - ->windowsExp) + - dataState * WIN_STATE_OFFSET; - } - - TPipe *tpipe_{nullptr}; - GlobalTensor xGT; - GlobalTensor expertIdsGT; - GlobalTensor sendOffsetGT; - GlobalTensor sendTokenIdxGT; - GlobalTensor recvOffsetGT; - GlobalTensor recvCountGT; - GlobalTensor dynamicScalesOutGT; - GlobalTensor expandIdxOutGT; - GlobalTensor dstGT; - GlobalTensor dstStatusGT; - - LocalTensor xInTensor; - LocalTensor xOutTensor; - LocalTensor xTmpTensor; - LocalTensor expertIdsTensor; - LocalTensor sendOffsetTensor; - LocalTensor sendTokenIdxTensor; - LocalTensor recvOffsetTensor; - LocalTensor recvCountTensor; - LocalTensor statusTensor; - - TBuf<> expertIdsBuf; - TBuf<> sendOffsetBuf; - TBuf<> sendTokenIdxBuf; - TBuf<> recvOffsetBuf; - TBuf<> recvCountBuf; - TBuf<> statusBuf; - TBuf<> waitStatusBuf; - TBuf<> gatherMaskOutBuf; - TBuf<> scalarBuf; - TBuf<> tokenCastFloatBuf; - TBuf<> tokenAbsFloatBuf; - - GM_ADDR expandXOutGM; - GM_ADDR shareGM; - - uint32_t batchSize{0}; - uint32_t globalBatchSize{0}; - uint32_t h{0}; - uint32_t topK{0}; - uint32_t blockNum{0}; - uint32_t blockIdx{0}; - uint32_t epRankSize{0}; - uint32_t epRankId{0}; - uint32_t tpRankSize{0}; - uint32_t tpRankId{0}; - uint32_t moeExpertNum{0}; - uint32_t moeExpertNumPerRank{0}; - - uint32_t hUBAlignSize{0}; - uint32_t hOutGMAlignSize{0}; - uint32_t hOutUBAlignSize{0}; - uint32_t hGMAlignCnt{0}; - uint32_t expandIdxStartIdx{0}; - uint32_t expertIdsCnt{0}; - uint32_t stateOffset{0}; - uint32_t dataState{0}; - uint32_t winDataSizeOffset{0}; - - uint32_t startStatusId; - uint32_t endStatusId; - uint32_t statusNumPerCore; - uint32_t remainStatus; - - TQueBind xQueue; - TQue xInQueue; - TQue xOutQueue; - - __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; - - DataCopyExtParams hCommuCopyOutParams; -}; - -template -__aicore__ inline void MoeDispatchNormal::Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, - GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, - GM_ADDR expandIdxOut, GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData) -{ - tpipe_ = pipe; - blockIdx = GetBlockIdx(); - - winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); - winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>(); - - GlobalTensor selfDataStatusTensor; - GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp); - selfDataStatusTensor.SetGlobalBuffer( - (__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET + blockIdx * WIN_ADDR_ALIGN)); - - batchSize = tilingData->moeDispatchNormalInfo.bs; - globalBatchSize = tilingData->moeDispatchNormalInfo.globalBs; - h = tilingData->moeDispatchNormalInfo.h; - topK = tilingData->moeDispatchNormalInfo.k; - blockNum = tilingData->moeDispatchNormalInfo.aivNum; - epRankSize = tilingData->moeDispatchNormalInfo.epWorldSize; - epRankId = tilingData->moeDispatchNormalInfo.epRankId; - moeExpertNum = tilingData->moeDispatchNormalInfo.moeExpertNum; - moeExpertNumPerRank = moeExpertNum / epRankSize; - - xGT.SetGlobalBuffer((__gm__ XType *)x); - expertIdsGT.SetGlobalBuffer((__gm__ int32_t *)expertIds); - sendOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(send_offset)); - sendTokenIdxGT.SetGlobalBuffer((__gm__ int32_t *)(send_tokenIdx)); - recvOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(recv_offset)); - recvCountGT.SetGlobalBuffer((__gm__ int32_t *)(recv_count)); - dynamicScalesOutGT.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); - expandIdxOutGT.SetGlobalBuffer((__gm__ int32_t *)(expandIdxOut)); - - expandXOutGM = expandXOut; - - hUBAlignSize = Ceil(h * sizeof(ExpandXOutType), UB_ALIGN) * UB_ALIGN; - uint32_t hScaleSizeAlign = hUBAlignSize + UB_ALIGN; - expandIdxStartIdx = hScaleSizeAlign / sizeof(int32_t); - - uint32_t hScaleIdxSize = hScaleSizeAlign + EXPAND_IDX_INFO * sizeof(int32_t); - hOutGMAlignSize = Ceil(hScaleIdxSize, WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - hGMAlignCnt = hOutGMAlignSize / sizeof(ExpandXOutType); - - expertIdsCnt = batchSize * topK; - statusNumPerCore = moeExpertNum / blockNum; - remainStatus = moeExpertNum % blockNum; - startStatusId = statusNumPerCore * blockIdx; - if (blockIdx < remainStatus) { - statusNumPerCore += 1; - startStatusId += blockIdx; - } else { - startStatusId += remainStatus; - } - endStatusId = startStatusId + statusNumPerCore; - stateOffset = STATE_OFFSET; - DataCacheCleanAndInvalid(selfDataStatusTensor); - dataState = selfDataStatusTensor(0); - if (dataState == 0) { - selfDataStatusTensor(0) = 1; - } else { - selfDataStatusTensor(0) = 0; - } - DataCacheCleanAndInvalid(selfDataStatusTensor); - PipeBarrier(); - - uint64_t hSizeAlignCombine = Ceil(h * sizeof(XType), WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - winDataSizeOffset = dataState * (tilingData->moeDispatchNormalInfo.totalWinSize / 2) + - globalBatchSize / epRankSize * topK * hSizeAlignCombine; - shareGM = GetWindAddrByRankId(COMM_EP_IDX, epRankId); - - hOutUBAlignSize = Ceil(hScaleIdxSize, UB_ALIGN) * UB_ALIGN; - if constexpr (DynamicQuant) { - QuantInit(); - } else { - tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 2 * 14K = 28K - } - - tpipe_->InitBuffer(sendOffsetBuf, moeExpertNum * sizeof(int32_t)); // 4 * moeNum - sendOffsetTensor = sendOffsetBuf.Get(); - - hCommuCopyOutParams = {1U, static_cast(hScaleIdxSize), 0U, 0U, 0U}; -} - -template -__aicore__ inline void MoeDispatchNormal::QuantInit() -{ - uint32_t hAlignSize = Ceil(h * sizeof(XType), UB_ALIGN) * UB_ALIGN; - tpipe_->InitBuffer(xInQueue, BUFFER_NUM, hAlignSize); // 14K * 2 - tpipe_->InitBuffer(xOutQueue, BUFFER_NUM, hOutUBAlignSize); // 7K * 2 - - tpipe_->InitBuffer(tokenCastFloatBuf, h * sizeof(float)); // 28K - tpipe_->InitBuffer(tokenAbsFloatBuf, h * sizeof(float)); // 28K -} - -template -__aicore__ inline void MoeDispatchNormal::ReduceMaxInplace( - const LocalTensor &srcLocal, uint32_t count) -{ - uint64_t repsFp32 = count >> 6; // 6 is count / elemPerRefFp32 - uint64_t offsetsFp32 = repsFp32 << 6; // 6 is repsFp32 * elemPerRefFp32 - uint64_t remsFp32 = count & 0x3f; // 0x3f 63, count % elemPerRefFp32 - const uint64_t elemPerRefFp32 = 64UL; // 256 bit / sizeof(float) - if (likely(repsFp32 > 1)) { - // 8 is rep stride - Max(srcLocal, srcLocal[elemPerRefFp32], srcLocal, elemPerRefFp32, repsFp32 - 1, {1, 1, 1, 0, 8, 0}); - PipeBarrier(); - } - if (unlikely(remsFp32 > 0) && unlikely(offsetsFp32 > 0)) { - Max(srcLocal, srcLocal[offsetsFp32], srcLocal, remsFp32, 1, {1, 1, 1, 0, 8, 0}); - PipeBarrier(); - } - uint32_t mask = (repsFp32 > 0) ? elemPerRefFp32 : count; - // 8 is rep stride - WholeReduceMax(srcLocal, srcLocal, mask, 1, 8, 1, 8); -} - -template -__aicore__ inline void MoeDispatchNormal::QuantProcess() -{ - float dynamicScale = 0.0; - LocalTensor floatLocalTemp; - floatLocalTemp = tokenCastFloatBuf.Get(); - - Cast(floatLocalTemp, xInTensor, RoundMode::CAST_NONE, h); - xInQueue.FreeTensor(xInTensor); - PipeBarrier(); - - if constexpr (DynamicQuant) { - LocalTensor floatLocalAbsTemp = tokenAbsFloatBuf.Get(); - - Abs(floatLocalAbsTemp, floatLocalTemp, h); - PipeBarrier(); - ReduceMaxInplace(floatLocalAbsTemp, h); - - SyncFunc(); - dynamicScale = float(127.0) / (floatLocalAbsTemp.GetValue(0) + 1e-12f); - SyncFunc(); - Muls(floatLocalTemp, floatLocalTemp, dynamicScale, h); - PipeBarrier(); - } - LocalTensor halfLocalTemp = floatLocalTemp.ReinterpretCast(); - LocalTensor int32LocalTemp = floatLocalTemp.ReinterpretCast(); - Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, h); - PipeBarrier(); - SetDeqScale((half)1.000000e+00f); - PipeBarrier(); - - Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, h); - - PipeBarrier(); - Cast(xOutTensor, halfLocalTemp, RoundMode::CAST_TRUNC, h); - - floatLocalTemp = xOutTensor.template ReinterpretCast(); - floatLocalTemp.SetValue(hUBAlignSize / sizeof(float), float(1.0) / dynamicScale); // int8->float32 -} - -template -__aicore__ inline void MoeDispatchNormal::FillTriple( - LocalTensor &xOutTensor, uint32_t tokenIndex, uint32_t k) -{ - SyncFunc(); - LocalTensor xOutTint32 = xOutTensor.template ReinterpretCast(); - xOutTint32(expandIdxStartIdx) = epRankId; - xOutTint32(expandIdxStartIdx + 1) = tokenIndex; - xOutTint32(expandIdxStartIdx + 2) = k; - SyncFunc(); -} - -template -__aicore__ inline void MoeDispatchNormal::InputToShare() -{ - DataCopyExtParams sendOffsetParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; - DataCopyPadExtParams sendOffsetCopyPadParams{false, 0U, 0U, 0U}; - DataCopyPad(sendOffsetTensor, sendOffsetGT, sendOffsetParams, sendOffsetCopyPadParams); - SyncFunc(); - - uint32_t startTokenId, endTokenId, sendTokenNum, remainTokenNum; - sendTokenNum = expertIdsCnt / blockNum; - remainTokenNum = expertIdsCnt % blockNum; - startTokenId = sendTokenNum * blockIdx; - if (blockIdx < remainTokenNum) { - sendTokenNum += 1; - startTokenId += blockIdx; - } else { - startTokenId += remainTokenNum; - } - endTokenId = startTokenId + sendTokenNum; - - if (startTokenId >= expertIdsCnt) { - return; - } - tpipe_->InitBuffer(expertIdsBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48 - tpipe_->InitBuffer(sendTokenIdxBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48 - expertIdsTensor = expertIdsBuf.Get(); - sendTokenIdxTensor = sendTokenIdxBuf.Get(); - DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U}; - DataCopyExtParams sendTokenIdxParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U}; - DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPadExtParams tokenCopyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(expertIdsTensor, expertIdsGT[startTokenId], expertIdsCntParams, copyPadExtParams); - DataCopyPad(sendTokenIdxTensor, sendTokenIdxGT[startTokenId], sendTokenIdxParams, copyPadExtParams); - SyncFunc(); - - DataCopyExtParams xCopyParams = {1U, static_cast(h * sizeof(XType)), 0U, 0U, 0U}; - for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { - uint32_t dstExpertId = expertIdsTensor(tokenIndex - startTokenId); - int32_t curExpertCnt = sendTokenIdxTensor(tokenIndex - startTokenId); - int32_t dstExpertOffset = sendOffsetTensor(dstExpertId); - GM_ADDR rankGM = - (__gm__ uint8_t *)(shareGM + hOutGMAlignSize * (dstExpertOffset + curExpertCnt)); - dstGT.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM); - - if constexpr (DynamicQuant) { - xInTensor = xInQueue.AllocTensor(); - DataCopyPad(xInTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams); - xInQueue.EnQue(xInTensor); - xInTensor = xInQueue.DeQue(); - xOutTensor = xOutQueue.AllocTensor(); - QuantProcess(); - xOutQueue.EnQue(xOutTensor); - xOutTensor = xOutQueue.DeQue(); - FillTriple(xOutTensor, tokenIndex / topK, tokenIndex % topK); - DataCopyPad(dstGT, xOutTensor, hCommuCopyOutParams); - xOutQueue.FreeTensor(xOutTensor); - } else { - xTmpTensor = xQueue.AllocTensor(); - DataCopyPad(xTmpTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams); - xQueue.EnQue(xTmpTensor); - xTmpTensor = xQueue.DeQue(); - FillTriple(xTmpTensor, tokenIndex / topK, tokenIndex % topK); - DataCopyPad(dstGT, xTmpTensor, hCommuCopyOutParams); - xQueue.FreeTensor(xTmpTensor); - } - } -} - -template -__aicore__ inline void MoeDispatchNormal::SetStatus() -{ - uint32_t startExpId, endExpId, expNumPerCore; - expNumPerCore = statusNumPerCore; - startExpId = startStatusId; - endExpId = endStatusId; - if (startExpId > moeExpertNum) { - SyncAll(); - return; - } - uint32_t statusCntAlign = Ceil(expNumPerCore, 8) * 8; - tpipe_->InitBuffer(statusBuf, statusCntAlign * UB_ALIGN); // moeNum / 48 * 32 - statusTensor = statusBuf.Get(); - Duplicate(statusTensor, 0, expNumPerCore * 8); - uint64_t mask[2] = {0x101010101010101, 0}; - PipeBarrier(); - Duplicate(statusTensor, 0x3F800000, mask, statusCntAlign / 8, 1, 8); - PipeBarrier(); - SyncAll(); - for (uint32_t i = startExpId; i < endExpId; ++i) { - uint32_t targetRankId = i / moeExpertNumPerRank; - uint32_t offset = stateOffset * (epRankId + i % moeExpertNumPerRank * epRankSize); - GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, targetRankId) + offset); - dstStatusGT.SetGlobalBuffer((__gm__ int32_t *)rankGM); - DataCopy(dstStatusGT, statusTensor[(i - startExpId) * 8], 8UL); - } - SyncFunc(); -} - -template -__aicore__ inline void MoeDispatchNormal::WaitStatus() -{ - tpipe_->Reset(); - uint32_t waitStatusBufSize = (((statusNumPerCore * UB_ALIGN) > 256) ? (statusNumPerCore * UB_ALIGN) : 256); - tpipe_->InitBuffer(waitStatusBuf, waitStatusBufSize); // moeNum /48 * 32B = 43 * 32B - tpipe_->InitBuffer(gatherMaskOutBuf, moeExpertNum * sizeof(float)); // moeNum * 4B - tpipe_->InitBuffer(scalarBuf, UB_ALIGN * 3); // 96B - tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 28K - tpipe_->InitBuffer(recvOffsetBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B - tpipe_->InitBuffer(recvCountBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B - - recvOffsetTensor = recvOffsetBuf.Get(); - recvCountTensor = recvCountBuf.Get(); - DataCopyExtParams recvOffsetParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; - DataCopyExtParams recvCountParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; - DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(recvOffsetTensor, recvOffsetGT, recvOffsetParams, copyPadExtParams); - DataCopyPad(recvCountTensor, recvCountGT, recvCountParams, copyPadExtParams); - - if (startStatusId >= moeExpertNum) { - SyncAll(); - return; - } - - LocalTensor gatherMaskOutTensor = gatherMaskOutBuf.Get(); - LocalTensor statusSumOutTensor = scalarBuf.GetWithOffset(UB_ALIGN / sizeof(float), UB_ALIGN); - LocalTensor statusFp32Tensor = waitStatusBuf.Get(); - GlobalTensor windowInstatusFp32Tensor; - windowInstatusFp32Tensor.SetGlobalBuffer((__gm__ float *)(GetWindStateAddrByRankId(COMM_EP_IDX, epRankId))); - uint32_t mask = 1; - float compareTarget = static_cast(1.0) * statusNumPerCore; - float sumOfFlag = static_cast(-1.0); - DataCopyParams intriParams{static_cast(statusNumPerCore), 1, 0, 0}; - SyncFunc(); - while (sumOfFlag != compareTarget) { - DataCopy(statusFp32Tensor, windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)], intriParams); - SyncFunc(); - ReduceSum(statusSumOutTensor, statusFp32Tensor, gatherMaskOutTensor, mask, statusNumPerCore, 1); - SyncFunc(); - sumOfFlag = statusSumOutTensor.GetValue(0); - } - - // Clear state - SyncFunc(); - DataCopyParams intriOutParams{static_cast(statusNumPerCore), 1, 0, 0}; - uint64_t duplicateMask[2] = {0x101010101010101, 0}; - LocalTensor cleanStateTensor = waitStatusBuf.Get(); - SyncFunc(); - Duplicate(cleanStateTensor, 0, duplicateMask, Ceil(statusNumPerCore, 8), 1, 8); - SyncFunc(); - DataCopy(windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)], - cleanStateTensor.ReinterpretCast(), - intriOutParams); - SyncFunc(); - SyncAll(); -} - -template -__aicore__ inline void MoeDispatchNormal::ShareToOutput() -{ - if (startStatusId >= moeExpertNum) { - return; - } - uint32_t fromRank, count, preCount, recvOffset, targetOffset; - DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyExtParams dataCopyExandIdxParams{1U, sizeof(int32_t) * EXPAND_IDX_INFO, 0U, 0U, 0U}; - DataCopyExtParams dataCopyOutParams{1U, static_cast(statusNumPerCore * sizeof(int32_t)), 0U, 0U, 0U}; - DataCopyExtParams expandXCopyParams = {1U, static_cast(h * sizeof(ExpandXOutType)), 0U, 0U, 0U}; - LocalTensor xTmpTensorInt; - AscendC::TQueSync recvCountLocalSync; - recvCountLocalSync.SetFlag(0); - recvCountLocalSync.WaitFlag(0); - for (uint32_t i = startStatusId; i < endStatusId; ++i) { - preCount = 0; - if (likely(i != 0)) { - preCount = recvCountTensor(i - 1); - } - fromRank = i % epRankSize; - count = recvCountTensor(i) - preCount; - recvOffset = recvOffsetTensor(i); - targetOffset = preCount; - GM_ADDR recvStart = - (__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, fromRank)) + recvOffset * hOutGMAlignSize; - GlobalTensor srcTokenGT, dstTokenGT; - for (uint32_t j = 0; j < count; ++j) { - srcTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(recvStart + j * hOutGMAlignSize)); - xTmpTensor = xQueue.AllocTensor(); - DataCopyPad(xTmpTensor, srcTokenGT, hCommuCopyOutParams, copyPadExtParams); - xQueue.EnQue(xTmpTensor); - xTmpTensor = xQueue.DeQue(); - xTmpTensorInt = xTmpTensor.template ReinterpretCast(); - DataCopyPad(expandIdxOutGT[(targetOffset + j) * EXPAND_IDX_INFO], - xTmpTensorInt[expandIdxStartIdx], - dataCopyExandIdxParams); - if constexpr (DynamicQuant) { - DataCopyExtParams floatDataCopyParams = {1U, sizeof(float), 0U, 0U, 0U}; - LocalTensor xOutFp32Tensor = xTmpTensor.template ReinterpretCast(); - DataCopyPad(dynamicScalesOutGT[targetOffset + j], - xOutFp32Tensor[hUBAlignSize / sizeof(float)], - floatDataCopyParams); - } - dstTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM) + (targetOffset + j) * h, h); - DataCopyPad(dstTokenGT, xTmpTensor, expandXCopyParams); - xQueue.FreeTensor(xTmpTensor); - } - } -} - -template -__aicore__ inline void MoeDispatchNormal::Process() -{ - if ASCEND_IS_AIV { - InputToShare(); - SetStatus(); - WaitStatus(); - ShareToOutput(); - } -} - -} // namespace MoeDispatchNormalImpl -#endif \ No newline at end of file diff --git a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal_tiling.h b/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal_tiling.h deleted file mode 100644 index 11fd1255..00000000 --- a/csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal_tiling.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef MOE_DISPATCH_NORMAL_TILING_H -#define MOE_DISPATCH_NORMAL_TILING_H - -struct MoeDispatchNormalInfo { - uint32_t epWorldSize; // epWorldSize - uint32_t tpWorldSize; // tpWorldSize - uint32_t epRankId; // epRankId - uint32_t tpRankId; // tpRankId - uint32_t moeExpertNum; // moe expert number - uint32_t quantMode; // quant mode - uint32_t globalBs; // globalBs = BS * worldSize - uint32_t bs; // bs - uint32_t k; // k - uint32_t h; // h - uint32_t aivNum; // aivNum - bool isQuant; // whether quant or not - bool reserved2; // reserved - bool reserved3; // reserved - uint64_t totalUbSize; // epWorldSize - uint64_t totalWinSize; -}; - -struct MoeDispatchNormalTilingData { - Mc2InitTiling mc2InitTiling; - Mc2CcTiling mc2CcTiling1; - Mc2CcTiling mc2CcTiling2; - MoeDispatchNormalInfo moeDispatchNormalInfo; -}; - -#endif \ No newline at end of file diff --git a/csrc/notify_dispatch/op_host/CMakeLists.txt b/csrc/notify_dispatch/op_host/CMakeLists.txt deleted file mode 100644 index 990115ce..00000000 --- a/csrc/notify_dispatch/op_host/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_ops_compile_options( - OP_NAME NotifyDispatch - OPTIONS --cce-auto-sync=off - -Wno-deprecated-declarations - -Werror -) - -target_sources(op_host_aclnnInner PRIVATE - notify_dispatch.cpp -) - -target_sources(opapi PRIVATE - aclnn_notify_dispatch.cpp -) - -if (NOT BUILD_OPEN_PROJECT) - target_sources(aclnn_ops_train PRIVATE - aclnn_notify_dispatch.cpp - ) - - target_sources(aclnn_ops_infer PRIVATE - aclnn_notify_dispatch.cpp - ) -endif () - -target_sources(optiling PRIVATE - notify_dispatch_tiling.cpp -) - -target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_sources(opsproto PRIVATE) - -file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_notify_dispatch.h") - -install(FILES ${_GMM_Aclnn_header} - DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL -) diff --git a/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp b/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp deleted file mode 100644 index e808798a..00000000 --- a/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include -#include "graph/types.h" -#include "aclnn_notify_dispatch.h" - -extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr); - -#ifdef __cplusplus -extern "C" { -#endif - -enum NnopbaseHcclServerType { - NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, - NNOPBASE_HCCL_SERVER_TYPE_MTE, - NNOPBASE_HCCL_SERVER_TYPE_END -}; -extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); - -extern aclnnStatus aclnnInnerNotifyDispatchGetWorkspaceSize( - const aclTensor *sendData, - const aclTensor *tokenPerExpertData, - int64_t sendCount, - int64_t numTokens, - char *commGroup, - int64_t rankSize, - int64_t rankId, - int64_t localRankSize, - int64_t localRankId, - const aclTensor *sendDataOffset, - const aclTensor *recvData, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -extern aclnnStatus aclnnInnerNotifyDispatch( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -aclnnStatus aclnnNotifyDispatchGetWorkspaceSize( - const aclTensor *sendData, - const aclTensor *tokenPerExpertData, - int64_t sendCount, - int64_t numTokens, - char *commGroup, - int64_t rankSize, - int64_t rankId, - int64_t localRankSize, - int64_t localRankId, - const aclTensor *sendDataOffset, - const aclTensor *recvData, - uint64_t *workspaceSize, - aclOpExecutor **executor) -{ - return aclnnInnerNotifyDispatchGetWorkspaceSize( - sendData, - tokenPerExpertData, - sendCount, - numTokens, - commGroup, - rankSize, - rankId, - localRankSize, - localRankId, - sendDataOffset, - recvData, - workspaceSize, - executor); -} - -aclnnStatus aclnnNotifyDispatch( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream) -{ - if (NnopbaseSetHcclServerType) { - NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); - } - return aclnnInnerNotifyDispatch(workspace, workspaceSize, executor, stream); -} - -#ifdef __cplusplus -} -#endif diff --git a/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h b/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h deleted file mode 100644 index be9ae04f..00000000 --- a/csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h +++ /dev/null @@ -1,61 +0,0 @@ - -#ifndef ACLNN_NOTIFY_DISPATCH_H_ -#define ACLNN_NOTIFY_DISPATCH_H_ - -#include "aclnn/acl_meta.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/* funtion: aclnnNotifyDispatchGetWorkspaceSize - * parameters : - * sendData : required - * tokenPerExpertData : required - * sendCount : required - * numTokens : required - * commGroup : required - * rankSize : required - * rankId : required - * localRankSize : required - * localRankId : required - * sendDataOffset : required - * recvData : required - * workspaceSize : size of workspace(output). - * executor : executor context(output). - */ -__attribute__((visibility("default"))) -aclnnStatus aclnnNotifyDispatchGetWorkspaceSize( - const aclTensor *sendData, - const aclTensor *tokenPerExpertData, - int64_t sendCount, - int64_t numTokens, - char *commGroup, - int64_t rankSize, - int64_t rankId, - int64_t localRankSize, - int64_t localRankId, - const aclTensor *sendDataOffset, - const aclTensor *recvData, - uint64_t *workspaceSize, - aclOpExecutor **executor); - -/* funtion: aclnnNotifyDispatch - * parameters : - * workspace : workspace memory addr(input). - * workspaceSize : size of workspace(input). - * executor : executor context(input). - * stream : acl stream. - */ -__attribute__((visibility("default"))) -aclnnStatus aclnnNotifyDispatch( - void *workspace, - uint64_t workspaceSize, - aclOpExecutor *executor, - aclrtStream stream); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/csrc/notify_dispatch/op_host/notify_dispatch.cpp b/csrc/notify_dispatch/op_host/notify_dispatch.cpp deleted file mode 100644 index 33999266..00000000 --- a/csrc/notify_dispatch/op_host/notify_dispatch.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "register/op_def_registry.h" - -namespace ops { -class NotifyDispatch : public OpDef { -public: - explicit NotifyDispatch(const char *name) : OpDef(name) - { - this->Input("sendData") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("tokenPerExpertData") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("sendDataOffset") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("recvData") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Attr("sendCount").Int(); - this->Attr("num_tokens").Int(); - this->Attr("comm_group").String(); - this->Attr("rank_size").Int(); - this->Attr("rank_id").Int(); - this->Attr("local_rank_size").Int(); - this->Attr("local_rank_id").Int(); - - OpAICoreConfig aicore_config_base; - aicore_config_base.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); - - OpAICoreConfig aicore_config_A2 = aicore_config_base; - aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false"); - - OpAICoreConfig aicore_config = aicore_config_base; - aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true"); - - this->AICore().AddConfig("ascend910_93", aicore_config); - this->AICore().AddConfig("ascend910b", aicore_config_A2); - this->MC2().HcclGroup("comm_group"); - } -}; - -OP_ADD(NotifyDispatch); -} // namespace ops diff --git a/csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp b/csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp deleted file mode 100644 index 65041dcd..00000000 --- a/csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp +++ /dev/null @@ -1,306 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "log/ops_log.h" -#include "graph/utils/type_utils.h" -#include "register/op_def_registry.h" -#include "../op_kernel/notify_dispatch_tiling.h" -#include "tiling/platform/platform_ascendc.h" -#include "tiling/hccl/hccl_tiling.h" -#include "experiment/platform/platform/platform_infos_def.h" - -using namespace ge; -namespace { -class Mc2TilingUtils { -public: -#define HCCL_BUFFSIZE "HCCL_BUFFSIZE" - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - if (getenv(HCCL_BUFFSIZE) == nullptr) { - OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(HCCL_BUFFSIZE)); - defaultWindowSize = std::stoi(envStr); - } catch (const std::invalid_argument &ia) { - OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); - } catch (const std::out_of_range &oor) { - OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } -}; -constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll - -constexpr uint32_t INPUT_SEND_DATA_INDEX = 0; -constexpr uint32_t INPUT_TOKEN_PER_EXPERT_INDEX = 1; - -constexpr uint32_t OUTPUT_SEND_DATA_OFFSET_INDEX = 0; -constexpr uint32_t OUTPUT_RECV_DATA_INDEX = 1; - -constexpr uint32_t ATTR_SEND_COUNT_INDEX = 0; -constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 1; -constexpr uint32_t ATTR_COMM_GROUP_INDEX = 2; -constexpr uint32_t ATTR_RANK_SIZE_INDEX = 3; -constexpr uint32_t ATTR_RANK_ID_INDEX = 4; -constexpr uint32_t ATTR_LOCAL_RANK_SIZE_INDEX = 5; -constexpr uint32_t ATTR_LOCAL_RANK_ID_INDEX = 6; - -const size_t MAX_GROUP_NAME_LENGTH = 128UL; -const int64_t MAX_COMM_WORLD_SIZE = 384; - -constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; -constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024; -constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024; -constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes -constexpr uint64_t MB_SIZE = 1024UL * 1024UL; - -constexpr static int TILING_KEY_FLOAT16 = 20; -constexpr static int TILING_KEY_BFLOAT16 = 21; -constexpr static int TILING_KEY_FLOAT = 22; -constexpr static int TILING_KEY_INT = 23; -constexpr static int TILING_KEY_A2_TYPE = 100; - -constexpr static int ALL_TO_ALL_CORE_NUM = 32; -} // namespace - -namespace optiling { -static void PrintTilingDataInfo(const char *nodeName, NotifyDispatchTilingData &tilingData) -{ - OPS_LOG_D(nodeName, "rankSize is %u.", tilingData.notifyDispatchInfo.rankSize); - OPS_LOG_D(nodeName, "rankId is %u.", tilingData.notifyDispatchInfo.rankId); - OPS_LOG_D(nodeName, "localRankSize is %u.", tilingData.notifyDispatchInfo.localRankSize); - OPS_LOG_D(nodeName, "localRankId is %u.", tilingData.notifyDispatchInfo.localRankId); - OPS_LOG_D(nodeName, "sendCount is %u.", tilingData.notifyDispatchInfo.sendCount); - OPS_LOG_D(nodeName, "numTokens is %u.", tilingData.notifyDispatchInfo.numTokens); - OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.notifyDispatchInfo.aivNum); - OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.notifyDispatchInfo.totalUbSize); -} - -static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, - NotifyDispatchTilingData &tilingData, std::string &commGroup) -{ - auto attrs = context->GetAttrs(); - OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); - - auto sendCountPtr = attrs->GetAttrPointer(ATTR_SEND_COUNT_INDEX); - auto numTokenPtr = attrs->GetAttrPointer(ATTR_NUM_TOKENS_INDEX); - auto commGroupPtr = attrs->GetAttrPointer(static_cast(ATTR_COMM_GROUP_INDEX)); - auto rankSizePtr = attrs->GetAttrPointer(ATTR_RANK_SIZE_INDEX); - auto rankIdPtr = attrs->GetAttrPointer(ATTR_RANK_ID_INDEX); - auto localRankSizePtr = attrs->GetAttrPointer(ATTR_LOCAL_RANK_SIZE_INDEX); - auto localRankIdPtr = attrs->GetAttrPointer(ATTR_LOCAL_RANK_ID_INDEX); - - OPS_CHECK((commGroupPtr == nullptr) || (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == 0) || - (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), - OPS_LOG_E(nodeName, "commGroupPtr is null."), - return ge::GRAPH_FAILED); - OPS_CHECK(sendCountPtr == nullptr, OPS_LOG_E(nodeName, "sendCountPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(numTokenPtr == nullptr, OPS_LOG_E(nodeName, "numTokenPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(rankSizePtr == nullptr, OPS_LOG_E(nodeName, "rankSizePtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(rankIdPtr == nullptr, OPS_LOG_E(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK( - localRankSizePtr == nullptr, OPS_LOG_E(nodeName, "localRankSizePtr is null."), return ge::GRAPH_FAILED); - OPS_CHECK(localRankIdPtr == nullptr, OPS_LOG_E(nodeName, "localRankIdPtr is null."), return ge::GRAPH_FAILED); - - OPS_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_COMM_WORLD_SIZE), - OPS_LOG_E(nodeName, - "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", - MAX_COMM_WORLD_SIZE, - *rankSizePtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*rankIdPtr < 0) || (*rankIdPtr >= *rankSizePtr), - OPS_LOG_E(nodeName, "rankId is invalid, only support [0, %ld), but got rankId=%ld.", *rankSizePtr, *rankIdPtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*sendCountPtr <= 0), - OPS_LOG_E(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%ld.", *sendCountPtr), - return ge::GRAPH_FAILED); - OPS_CHECK((*numTokenPtr <= 0), - OPS_LOG_E(nodeName, "numTokenPtr is invalid, only support > 0, but got numTokenPtr=%ld.", *numTokenPtr), - return ge::GRAPH_FAILED); - - commGroup = std::string(commGroupPtr); - tilingData.notifyDispatchInfo.rankSize = static_cast(*rankSizePtr); - tilingData.notifyDispatchInfo.rankId = static_cast(*rankIdPtr); - tilingData.notifyDispatchInfo.localRankSize = static_cast(*localRankSizePtr); - tilingData.notifyDispatchInfo.localRankId = static_cast(*localRankIdPtr); - tilingData.notifyDispatchInfo.sendCount = static_cast(*sendCountPtr); - tilingData.notifyDispatchInfo.numTokens = static_cast(*numTokenPtr); - - return ge::GRAPH_SUCCESS; -} - -static void SetHcommCfg(const gert::TilingContext *context, - NotifyDispatchTilingData *tiling, const std::string commGroup) -{ - const char *nodeName = context->GetNodeName(); - OPS_LOG_D(nodeName, "NotifyDispatch commGroup = %s", commGroup.c_str()); - uint32_t opType1 = OP_TYPE_ALL_TO_ALL; - std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; - - AscendC::Mc2CcTilingConfig mc2CcTilingConfig(commGroup, opType1, algConfigAllToAllStr); - mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); - mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1); -} - -static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) -{ - size_t *workSpaces = context->GetWorkspaceSizes(1); - OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); - workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE; - return ge::GRAPH_SUCCESS; -} - -static bool CheckTensorDataType( - gert::TilingContext *context, const char *nodeName) -{ - auto sendData = context->GetInputDesc(INPUT_SEND_DATA_INDEX); - OPS_CHECK(sendData == nullptr, OPS_LOG_E(nodeName, "sendData is null."), return false); - OPS_CHECK((sendData->GetDataType() != ge::DT_BF16) && (sendData->GetDataType() != ge::DT_FLOAT16) && - (sendData->GetDataType() != ge::DT_FLOAT) && (sendData->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, - "sendData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", - static_cast(sendData->GetDataType())), - return false); - uint64_t dataSize; - if ((sendData->GetDataType() == ge::DT_BF16) || (sendData->GetDataType() == ge::DT_FLOAT16)) { - dataSize = 2; - } else { - dataSize = 4; - } - auto tokenPerExpertData = context->GetInputDesc(INPUT_TOKEN_PER_EXPERT_INDEX); - OPS_CHECK(tokenPerExpertData == nullptr, OPS_LOG_E(nodeName, "tokenPerExpertData is null."), return false); - OPS_CHECK((tokenPerExpertData->GetDataType() != ge::DT_BF16) && (tokenPerExpertData->GetDataType() != ge::DT_FLOAT16) && - (tokenPerExpertData->GetDataType() != ge::DT_FLOAT) && (tokenPerExpertData->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, - "tokenPerExpertData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", - static_cast(tokenPerExpertData->GetDataType())), - return false); - - auto sendDataOffset = context->GetInputDesc(OUTPUT_SEND_DATA_OFFSET_INDEX); - OPS_CHECK(sendDataOffset == nullptr, OPS_LOG_E(nodeName, "sendDataOffset is null."), return false); - OPS_CHECK((sendDataOffset->GetDataType() != ge::DT_BF16) && (sendDataOffset->GetDataType() != ge::DT_FLOAT16) && - (sendDataOffset->GetDataType() != ge::DT_FLOAT) && (sendDataOffset->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, - "sendDataOffset datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", - static_cast(sendDataOffset->GetDataType())), - return false); - - auto recvData = context->GetInputDesc(OUTPUT_RECV_DATA_INDEX); - OPS_CHECK(recvData == nullptr, OPS_LOG_E(nodeName, "recvData is null."), return false); - OPS_CHECK((recvData->GetDataType() != ge::DT_BF16) && (recvData->GetDataType() != ge::DT_FLOAT16) && - (recvData->GetDataType() != ge::DT_FLOAT) && (recvData->GetDataType() != ge::DT_INT32), - OPS_LOG_E(nodeName, - "recvData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", - static_cast(recvData->GetDataType())), - return false); - - // Verify the size of the win area - NotifyDispatchTilingData *tilingData = context->GetTilingData(); - uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); - uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount; - if (actualSize > maxWindowSize) { - OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize); - return false; - } - return true; -} - -static ge::graphStatus TilingCheckTensor( - gert::TilingContext *context, const char *nodeName) -{ - OPS_CHECK(!CheckTensorDataType(context, nodeName), - OPS_LOG_E(nodeName, "params dataType is invalid."), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus NotifyDispatchTilingFuncImpl(gert::TilingContext *context) -{ - const char *nodeName = context->GetNodeName(); - NotifyDispatchTilingData *tilingData = context->GetTilingData(); - OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); - std::string commGroup = ""; - OPS_LOG_I(nodeName, "Enter NotifyDispatch tiling check func."); - - OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, commGroup) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), - return ge::GRAPH_FAILED); - - OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling check param failed."), - return ge::GRAPH_FAILED); - - OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling set workspace failed."), - return ge::GRAPH_FAILED); - SetHcommCfg(context, tilingData, commGroup); - - int tilingKey = TILING_KEY_INT; - auto sendDtype = context->GetInputDesc(0)->GetDataType(); - if (sendDtype == ge::DT_FLOAT16) { - tilingKey = TILING_KEY_FLOAT16; - } else if (sendDtype == ge::DT_BF16) { - tilingKey = TILING_KEY_BFLOAT16; - } else if (sendDtype == ge::DT_FLOAT) { - tilingKey = TILING_KEY_FLOAT; - } - - fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); - fe::PlatFormInfos &platformInfo = *platformInfoPtr; - - std::string socVersion; - (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); - - if (socVersion == "Ascend910B") { - tilingKey = tilingKey + TILING_KEY_A2_TYPE; - } - context->SetTilingKey(tilingKey); - - auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - uint32_t blockDim; - uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); - uint64_t ubSize = 0UL; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); - - blockDim = aivNum; - context->SetBlockDim(blockDim); - tilingData->notifyDispatchInfo.totalUbSize = ubSize; - tilingData->notifyDispatchInfo.aivNum = aivNum; - OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize); - PrintTilingDataInfo(nodeName, *tilingData); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus NotifyDispatchTilingFunc(gert::TilingContext *context) -{ - ge::graphStatus ret = NotifyDispatchTilingFuncImpl(context); - return ret; -} - -struct NotifyDispatchCompileInfo {}; -ge::graphStatus TilingParseForNotifyDispatch(gert::TilingParseContext *context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_OPTILING(NotifyDispatch) - .Tiling(NotifyDispatchTilingFunc) - .TilingParse(TilingParseForNotifyDispatch); -} // namespace optiling \ No newline at end of file diff --git a/csrc/notify_dispatch/op_kernel/notify_dispatch.cpp b/csrc/notify_dispatch/op_kernel/notify_dispatch.cpp deleted file mode 100644 index d641e1fa..00000000 --- a/csrc/notify_dispatch/op_kernel/notify_dispatch.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "kernel_operator.h" -#include "notify_dispatch.h" -#include "notify_dispatch_tiling.h" - -#define TILING_KEY_FLOAT16 20 -#define TILING_KEY_BFLOAT16 21 -#define TILING_KEY_FLOAT 22 -#define TILING_KEY_INT 23 - -#define KERNEL_USE_WORKSPACE (1 * 1024 * 1024) - -extern "C" __global__ __aicore__ void notify_dispatch( - GM_ADDR sendData, GM_ADDR tokenPerExpertData, GM_ADDR sendDataOffset, GM_ADDR recvData, GM_ADDR workspace, GM_ADDR tiling) -{ - REGISTER_TILING_DEFAULT(NotifyDispatchTilingData); - GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tiling); - - // hcomm will set magic later in init - uint32_t magic = 1; - GM_ADDR commArgs = nullptr; - - int localRank = tilingData.notifyDispatchInfo.localRankId; - int localRankSize = tilingData.notifyDispatchInfo.localRankSize; - int rank = tilingData.notifyDispatchInfo.rankId; - int rankSize = tilingData.notifyDispatchInfo.rankSize; - int64_t len = tilingData.notifyDispatchInfo.sendCount; - int64_t numTokens = tilingData.notifyDispatchInfo.numTokens; - - GM_ADDR sendDataInput = sendData; - GM_ADDR tokenPerExpertDataInput = tokenPerExpertData; - GM_ADDR sendDataOffsetOutput = sendDataOffset; - GM_ADDR recvDataOutput = recvData; - - // fill in unused args - uint32_t extraFlag = 0; - GM_ADDR scale = nullptr; - int root = 0; - int op = 0; - int cycleCount = 0; - int64_t scaleCount = 0; - GM_ADDR offset = nullptr; - int blockNum = GetBlockNum(); - - if (TILING_KEY_IS(TILING_KEY_FLOAT16)) { - NotifyDispatch opKernel(rank, rankSize, extraFlag); - opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); - opKernel.Process(); - } else if (TILING_KEY_IS(TILING_KEY_FLOAT)) { - NotifyDispatch opKernel(rank, rankSize, extraFlag); - opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); - opKernel.Process(); - } else if (TILING_KEY_IS(TILING_KEY_INT)) { - NotifyDispatch opKernel(rank, rankSize, extraFlag); - opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); - opKernel.Process(); - } -} \ No newline at end of file diff --git a/csrc/notify_dispatch/op_kernel/notify_dispatch.h b/csrc/notify_dispatch/op_kernel/notify_dispatch.h deleted file mode 100644 index 1952a6c3..00000000 --- a/csrc/notify_dispatch/op_kernel/notify_dispatch.h +++ /dev/null @@ -1,495 +0,0 @@ -#ifndef NOTIFY_DISPATCH_H -#define NOTIFY_DISPATCH_H - -#include -#include "kernel_operator.h" - -#include "../common/comm_args.h" -#include "../common/data_copy.h" -#include "../common/sync_collectives.h" -#include "../common/moe_distribute_base.h" - -using namespace AscendC; -using namespace Moe; - -#define KERNELS_ARGS_FUN_ALL2ALL() \ - GM_ADDR sendDataInput, GM_ADDR tokenPerExpertDataInput, GM_ADDR sendDataOffsetOutput, GM_ADDR recvDataOutput, \ - int64_t len, int64_t numTokens, int op, int root, int cycleCount, GM_ADDR scale, int64_t scaleCount, \ - GM_ADDR offset, int localRank, int localRankSize, GM_ADDR commArgs, int magic - -#define KERNELS_ARGS_CALL_ALL2ALL() \ - sendDataInput, tokenPerExpertDataInput, sendDataOffsetOutput, recvDataOutput, len, numTokens, op, root, \ - cycleCount, scale, scaleCount, offset, localRank, localRankSize, commArgs, magic - -template -class NotifyDispatch { - constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; // Invalid rank - constexpr static int64_t CORE_NUMS_PER_STAGE_X = 24; // Maximum number of cores provided by the producer stage - constexpr static int64_t CORE_NUMS_PER_STAGE_Y = 16; // Maximum number of cores provided by the consumer stage - constexpr static int64_t CORE_NUMS_PER_STAGE_Z = 16; // Maximum number of cores provided by the consumer stage 2 - constexpr static int64_t SHARE_QUE_DEPTH = 1; // Depth of a single shared queue - constexpr static int64_t RANK_NUM_PER_NODE = 16; - constexpr static int64_t SIO_NUM = 2; // Depth of a single shared queue - constexpr static int64_t MAX_CORE_NUM = 48; - constexpr static int64_t MAX_RANK_PER_CORE = 8; - constexpr static int64_t MULTI_RANK_SIZE = 48; - constexpr static int64_t MAX_BUFFER_NUMBER = 10; - - constexpr static int64_t IDLER_CORE = 0; // Idle core - constexpr static int64_t PRODUCER_CORE = 1; // Producer group, responsible for writing data to shared memory, input->share, or share->share - constexpr static int64_t CONSUMER_CORE = 2; // Consumer group, responsible for reading data from shared memory, share->output - constexpr static int64_t CONSUMER_CORE2 = 3; - -public: - __aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag) - : rank(rank), rankSize(rankSize), extraFlag(extraFlag) - {} - - __aicore__ inline void Init(KERNELS_ARGS_FUN_ALL2ALL()) - { - InitSmallFullMesh(KERNELS_ARGS_CALL_ALL2ALL()); - nodeNum = rankSize / localRankSize; - localRankId = rank % localRankSize; - localNodeId = rank / localRankSize; - perNodeDataNum = GetDataCount(len, nodeNum); // 128K/4 = 32K - perRankDataNum = GetDataCount(len, rankSize); // 128K/64 = 2K - - tokenPerExpertDataAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataOffsetAlignLen = Ceil(numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataAlignLen = Ceil(numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - - // Initialize core grouping - InitCoreGroup(); - // Initialize data slicing - InitDataSlice(); - - this->sendDataInput = (__gm__ T *)sendDataInput; - this->tokenPerExpertDataInput = (__gm__ int32_t *)tokenPerExpertDataInput; - this->sendDataOffsetOutput = (__gm__ T *)sendDataOffsetOutput; - this->recvDataOutput = (__gm__ T *)recvDataOutput; - sendDataInputGt.SetGlobalBuffer((__gm__ T *)sendDataInput); - tokenPerExpertDataInputGt.SetGlobalBuffer((__gm__ int32_t *)tokenPerExpertDataInput); - sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput); - recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput); - } - - __aicore__ inline void Process() - { - if (blockIdx < 1) { - AssembleSendData(); - } - SyncAll(); - if (blockIdx < coreNumPerStageX) { - InputToShareSlice(); - } - if (blockIdx < coreNumPerStageY) { - ShareToShareSlice(); - } - } - -private: - __aicore__ inline void InitCoreGroup() - { - coreNumPerStageY = MAX_CORE_NUM; - coreNumPerStageX = MAX_CORE_NUM; - rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM; - } - - __aicore__ inline void InitDataSlice() - { - // The producer is responsible for moving the input data of this rank to shared memory, input-->share - if (blockIdx < coreNumPerStageX) { - ProducerDataSlice(); - } - } - - __aicore__ inline void ProducerDataSlice() - { - // The ipcQue responsible for the current core - writeGt.SetGlobalBuffer((__gm__ T *)(shareAddrs[rank] + IPC_DATA_OFFSET)); - } - - __aicore__ inline void AssembleSendData() - { - pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen); - pipe.InitBuffer(sendDataBuf, sendDataAlignLen); - pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen); - - __ubuf__ int32_t *tokenPerExpertUB = (__ubuf__ int32_t *)get_imm(96); - CpGM2UB(tokenPerExpertUB, (__gm__ int32_t *)tokenPerExpertDataInputGt.GetPhyAddr(), tokenPerExpertDataAlignLen); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - - __ubuf__ T *sendDataOffsetUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen); - __ubuf__ T *sendDataUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen + sendDataOffsetAlignLen); - - int prefixSum = 0; - for (int i = 0; i < numExperts; ++i) { - int numTokensExpert = tokenPerExpertUB[i]; - sendDataUB[i * sendPerGroup] = numTokensExpert; - sendDataUB[i * sendPerGroup + 1] = prefixSum; - sendDataUB[i * sendPerGroup + 2] = numTokens; - sendDataOffsetUB[i] = prefixSum; - - prefixSum += numTokensExpert; - } - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - - CpUB2GM((__gm__ T *)sendDataInputGt.GetPhyAddr(), sendDataUB, sendDataAlignLen); - CpUB2GM((__gm__ T *)sendDataOffsetOutputGt.GetPhyAddr(), sendDataOffsetUB, sendDataOffsetAlignLen); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - } - - // copy input to other rank share - __aicore__ inline void InputToShareSlice() - { - __ubuf__ int64_t *inputUB = (__ubuf__ int64_t *)get_imm(0); - int64_t copyOffset = blockIdx * rankNumPerCore; - copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore; - if (copyLen > 0) { - readGt = sendDataInputGt[copyOffset * perRankDataNum]; - CpGM2GMPingPong( - copyLen * perRankDataNum * sizeof(T), readGt, writeGt[copyOffset * perRankDataNum], COPYONLY); - int64_t v = MergeMagicWithValue(magic, 1); - *inputUB = v; - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - for (int i = copyOffset; i < copyOffset + copyLen; ++i) { - CpUB2GM((__gm__ int64_t *)(shareAddrs[i]) + rank * FLAG_UNIT_INT_NUM, inputUB, sizeof(int64_t)); - } - pipe_barrier(PIPE_ALL); - } - } - - __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value) - { - // magic as the high part, eventID as the low part, combined into a value for comparison - return (static_cast(static_cast(magic)) << MAGIC_OFFSET) | static_cast(value); - } - - __aicore__ inline void ShareToShareSlice() - { - __ubuf__ T *inputUB = (__ubuf__ T *)get_imm(96); - int64_t copyOffset = blockIdx * rankNumPerCore; - copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore; - if (copyLen > 0) { - int checkRank[MAX_RANK_PER_CORE]; - for (int i = copyOffset; i < copyOffset + copyLen; ++i) { - checkRank[i - copyOffset] = i + rank % copyLen; - if (checkRank[i - copyOffset] >= copyOffset + copyLen) { - checkRank[i - copyOffset] -= copyLen; - } - } - for (int i = 0; i < copyLen; i++) { - readGt1[i].SetGlobalBuffer((__gm__ T *)(shareAddrs[checkRank[i]] + IPC_DATA_OFFSET)); - } - sync.WaitSyncFlag(magic, 1, copyOffset, rank, copyLen); - for (int i = 0; i < copyLen; i++) { - CpGM2GMPingPong(perRankDataNum * sizeof(T), - readGt1[i][rank * perRankDataNum], - recvDataOutputGt[checkRank[i] * perRankDataNum], - COPYONLY); - } - } - } - - FORCE_INLINE_AICORE int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum); - __aicore__ inline GM_ADDR GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx); - __aicore__ inline int32_t GetMagicValue(void); - FORCE_INLINE_AICORE void InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL()); - template - FORCE_INLINE_AICORE void SetAtomic(int op); - FORCE_INLINE_AICORE void UnsetAtomic(int op); - template - FORCE_INLINE_AICORE void SetWaitEvent(event_t eventId); - template - FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor& sendDataInputGt, - const GlobalTensor& recvDataOutputGT, int op); - - GlobalTensor sendDataInputGt; - GlobalTensor tokenPerExpertDataInputGt; - GlobalTensor sendDataOffsetOutputGt; - GlobalTensor recvDataOutputGt; - GlobalTensor readGt; - GlobalTensor writeGt; - GlobalTensor readGt1[MAX_BUFFER_NUMBER]; - GlobalTensor ipcGT; - GlobalTensor sendCountMatrixGm; - __gm__ T *sendDataInput; - __gm__ int *tokenPerExpertDataInput; - __gm__ T *sendDataOffsetOutput; - __gm__ T *recvDataOutput; - int64_t isPad = 0; - int64_t maxSliceNum; - int64_t revLen = 0; - int64_t sendLen = 0; - int64_t sliceLen; - int64_t perNodeDataNum; - int64_t perRankDataNum; - int64_t curRankDataNum; - int64_t sendOffset[MULTI_RANK_SIZE]; - int64_t revOffset[MULTI_RANK_SIZE]; - int64_t inputDataLen[MULTI_RANK_SIZE]; - - int64_t nodeNum; - int64_t localRankId; - int64_t localNodeId; - int64_t targetNode; - int64_t targetLocalRankIds[2]; - int64_t queLen; - int64_t queSize; - int64_t coreNumPerStageX; // Number of cores used per stage - int64_t coreNumPerStageY; // Number of cores used per stage - int64_t coreNumPerStageZ; // Number of cores used per stage - int64_t flagNumPerStage; // Number of synchronization flags used per stage - int64_t coreNumPerNode; // Number of cores allocated per node - int64_t coreNumPerRank; // Number of cores allocated per rank - int64_t rankNumPerCore; // Number of ranks responsible per core - int64_t coreGroup; // Functional group of the current core - int64_t targetRank[MULTI_RANK_SIZE]; // Ranks responsible by the current core - int64_t targetRankX; - int64_t targetRankY; - - int64_t queElemLen; // Size of each element in the shared memory queue (in terms of T) - - int64_t copyLen; // Length of the current data slice being copied (in terms of T) - - // for coll - int rank; - int rankSize; - int localRank = 0; - int localRankSize = 0; - int xRankSize = 0; - int yRankSize = 0; - int xRankIdx = 0; - int yRankIdx = 0; - uint32_t extraFlag; - int numTokens; - int sendPerGroup = 3; - int root; - int64_t len; - int64_t numExperts; - int64_t magic; - int64_t blockIdx; // Index of the current aicore - int64_t blockNum; // Total number of aicores for the current rank - int32_t numRanks; - int64_t timeout; - uint16_t *rootRanks; - GM_ADDR scale; - GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses - __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; - Hccl hccl_; - GlobalTensor peerMemsAddrGm_; - GlobalTensor dfx; - TPipe pipe; - TBuf tBuf; - TBuf<> tokenPerExpertDataBuf; - TBuf<> sendDataOffsetBuf; - TBuf<> sendDataBuf; - - uint32_t sendDataAlignLen{0}; - uint32_t tokenPerExpertDataAlignLen{0}; - uint32_t sendDataOffsetAlignLen{0}; - - SyncCollectives sync; -}; - -template -FORCE_INLINE_AICORE int64_t NotifyDispatch::GetDataCount(const int64_t dataLen, const int64_t useBlockNum) -{ - return dataLen / useBlockNum; -} - -template -__aicore__ inline GM_ADDR NotifyDispatch::GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx) -{ - uint32_t curRankId = rank; -#ifdef OPT_RANK_OFFSET -#pragma message("use rank offset") - if (curRankId == rankId) { - return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + rankId * OPT_RANK_OFFSET; - } - return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) + - rankId * OPT_RANK_OFFSET; -#else - if (curRankId == rankId) { - return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn); - } - return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn); -#endif -} - -// Assign values to winContext_[COMM_EP_IDX] and blockIdx before calling -template -__aicore__ inline int32_t NotifyDispatch::GetMagicValue(void) -{ - int32_t magic = 0; - GlobalTensor selfDataStatusTensor; - GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp); - selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); - DataCacheCleanAndInvalid( - selfDataStatusTensor[blockIdx * UB_ALIGN_SIZE]); - magic = selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE); - if (magic <= 0) { - magic = 1; - } - selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE) = magic + 1; - return magic; -} - -template -FORCE_INLINE_AICORE void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL()) -{ - this->root = root; - this->len = len; - this->numExperts = len / sendPerGroup; - this->numTokens = numTokens; - this->scale = scale; - this->localRank = localRank; - this->localRankSize = localRankSize; - this->xRankSize = localRankSize; - this->yRankSize = rankSize / localRankSize; - this->xRankIdx = rank % localRankSize; - this->yRankIdx = rank / localRankSize; - blockIdx = GetBlockIdx(); - blockNum = GetBlockNum(); - uint8_t ctxIdx; - - winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); - this->magic = GetMagicValue(); - ctxIdx = COMM_EP_IDX; - - shareAddrs[rank] = GetWindAddrByRankId(rank, ctxIdx) + - (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); - - int64_t rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM; - int64_t copyOffset = blockIdx * rankNumPerCore; - int64_t copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore; - if (copyLen > 0) { - for (int i = copyOffset; i < copyOffset + copyLen; ++i) { - shareAddrs[i] = GetWindAddrByRankId(i, ctxIdx) + - (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); - } - } - - // When the number of cores is more than the number of ranks, each core is responsible for fetching data from a specified rank - int coreNumPerRank = blockNum / rankSize; // Calculate the number of cores assigned to read for each rank, e.g., 48 cores 4 ranks, each rank is assigned 12 cores - int maxCore = coreNumPerRank * rankSize; // Calculate the maximum number of cores that can be used for reading, cores exceeding this number will not take action - if (blockIdx < maxCore) { - int readRank = blockIdx / coreNumPerRank; // Calculate the rank to be read based on the block, 48 cores divided into 4 groups - shareAddrs[readRank] = GetWindAddrByRankId(readRank, ctxIdx) + - (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); - } - - pipe.InitBuffer(tBuf, UB_SINGLE_TOTAL_SIZE_MAX); - - sync.Init(rank, rankSize, shareAddrs, tBuf); -} - -/** - * @brief Copy data from GM to GM with ping-pong method. - * @tparam dataSizeRemain The remaining size of data to be copied. - * @tparam K The type of output data. - * @tparam U The type of input data. - * @param sendDataInputGt The global tensor of send data. - * @param recvDataOutputGT The global tensor of recv data. - * @param op The operation to be performed during the copy. - * @details This function copies data from global memory to global memory using a ping-pong method. - * It first checks if the input and output types are the same. If they are, it uses a single buffer. - * If they are not, it divides the buffer according to the size ratio of the types and aligns it to 32 bytes. - * Then, it sets the atomic operation, waits for the flags, and performs the copy operation. - */ -template -template -FORCE_INLINE_AICORE void NotifyDispatch::CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor& sendDataInputGt, - const GlobalTensor& recvDataOutputGT, int op) -{ - // General case (U = K), input/output are the same, share one UB - // Only when conversion is needed (U->K), UB will be divided into two parts according to the ratio of sizeof(U):sizeof(K) and aligned to 32 bytes - constexpr int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; - constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(K) + sizeof(U)) / UB_ALIGN_SIZE * UB_ALIGN_SIZE; - constexpr int32_t inputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(U); - constexpr int32_t outputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(K); - - __gm__ U *input = const_cast<__gm__ U *>(sendDataInputGt.GetPhyAddr()); - __gm__ K *output = const_cast<__gm__ K *>(recvDataOutputGT.GetPhyAddr()); - __ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET), (__ubuf__ U*)(UB_MID_OFFSET)}; - __ubuf__ K* outputUB[2] = {(__ubuf__ K*)inputUB[0], (__ubuf__ K*)inputUB[1]}; - if constexpr (!std::is_same_v) { - outputUB[0] = (__ubuf__ K*)(inputUB[0] + inputUbBlockSize / sizeof(U)); - outputUB[1] = (__ubuf__ K*)(inputUB[1] + inputUbBlockSize / sizeof(U)); - } - int inputOffsetNum = 0; - int outputOffsetNum = 0; - if (dataSizeRemain <= 0) { - return; - } - - SetAtomic(op); - - AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 - AscendC::SetFlag(EVENT_ID1); // MTE2 waits for MTE3 - for (int64_t i = 0; dataSizeRemain > 0; i++) { - // size and dataSizeRemain both refer to the output size - uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain; - event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; - AscendC::WaitFlag(eventId); - CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(K) * sizeof(U)); - if constexpr (!std::is_same_v) { - SetWaitEvent(eventId); - CastImpl((i & 1) ? outputUB[0] : outputUB[1], (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE, - size / sizeof(K)); - SetWaitEvent(eventId); - } - AscendC::SetFlag(eventId); - AscendC::WaitFlag(eventId); - CpUB2GM(output + outputOffsetNum, (i & 1) ? outputUB[0] : outputUB[1], size); - AscendC::SetFlag(eventId); - - dataSizeRemain -= size; - inputOffsetNum += (size / sizeof(K)); - outputOffsetNum += (size / sizeof(K)); - } - AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 - AscendC::WaitFlag(EVENT_ID1); // MTE2 waits for MTE3 - - AscendC::SetFlag(EVENT_ID3); // Scalar waits for MTE3 - AscendC::WaitFlag(EVENT_ID3); - - UnsetAtomic(op); - return; -} - -template -template -FORCE_INLINE_AICORE void NotifyDispatch::SetAtomic(int op) -{ - PipeBarrier(); - if (op != -1) { -#ifdef __DAV_C220_VEC__ - SetAtomicOpType(op); -#endif - } - PipeBarrier(); -} - -template -FORCE_INLINE_AICORE void NotifyDispatch::UnsetAtomic(int op) -{ - if (op != -1) { - AscendC::SetAtomicNone(); - } - PipeBarrier(); -} - -template -template -FORCE_INLINE_AICORE void NotifyDispatch::SetWaitEvent(event_t eventId) -{ - AscendC::SetFlag(eventId); - AscendC::WaitFlag(eventId); -} - -#endif // NOTIFY_DISPATCH_H diff --git a/csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h b/csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h deleted file mode 100644 index 5a00b188..00000000 --- a/csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef NOTIFY_DISPATCH_TILING_H -#define NOTIFY_DISPATCH_TILING_H - -#include "kernel_tiling/kernel_tiling.h" - -struct NotifyDispatchInfo { - uint32_t rankSize; - uint32_t rankId; - uint32_t localRankSize; - uint32_t localRankId; - uint32_t sendCount; - uint32_t numTokens; - uint32_t aivNum; - uint64_t totalUbSize; -}; - -struct NotifyDispatchTilingData { - Mc2InitTiling mc2InitTiling; - Mc2CcTiling mc2CcTiling1; - NotifyDispatchInfo notifyDispatchInfo; -}; - -#endif \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 37efc3c2..706b711b 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include "torch_npu/csrc/core/npu/NPUGuard.h" #include #include "acl/acl.h" @@ -809,246 +808,6 @@ at::Tensor npu_sparse_flash_attention( return output; } -std::tuple get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts, - int64_t num_ranks) { - TORCH_BIND_ASSERT(topk_idx.dim() == 2); - TORCH_BIND_ASSERT(topk_idx.is_contiguous()); - TORCH_BIND_ASSERT(num_experts > 0); - - const int num_tokens = topk_idx.size(0); - const int num_topk = topk_idx.size(1); - - auto device = topk_idx.device(); - auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device)); - auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device)); - auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device)); - - EXEC_NPU_CMD(aclnnDispatchLayout, - topk_idx, - num_tokens, - num_ranks, - num_experts, - num_topk, - num_tokens_per_rank, - num_tokens_per_expert, - is_token_in_rank); - - auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool); - - return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool); -} - -std::tuple dispatch_prefill( - const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, - const at::Tensor& num_tokens_per_rank, const at::Tensor& is_token_in_rank, at::Tensor& num_tokens_per_expert, - int64_t num_worst_tokens, c10::string_view groupEp, int64_t rank, int64_t num_ranks) { - std::vector group_ep_chrs(groupEp.begin(), groupEp.end()); - group_ep_chrs.push_back('\0'); - char* group_ep_ptr = &group_ep_chrs[0]; - at::Tensor new_x = x; - - // Type checks - TORCH_BIND_ASSERT(is_token_in_rank.scalar_type() == at::kBool); - TORCH_BIND_ASSERT(num_tokens_per_expert.scalar_type() == at::kInt); - TORCH_BIND_ASSERT(num_tokens_per_rank.scalar_type() == at::kInt); - - // Shape and contiguous checks - TORCH_BIND_ASSERT(new_x.dim() == 2 and new_x.is_contiguous()); - // TORCH_BIND_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - TORCH_BIND_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); - TORCH_BIND_ASSERT(is_token_in_rank.size(0) == new_x.size(0) and is_token_in_rank.size(1) == num_ranks); - TORCH_BIND_ASSERT(num_tokens_per_expert.dim() == 1 and num_tokens_per_expert.is_contiguous()); - TORCH_BIND_ASSERT(num_tokens_per_expert.size(0) % num_ranks == 0); - TORCH_BIND_ASSERT(num_tokens_per_rank.dim() == 1 and num_tokens_per_rank.is_contiguous()); - TORCH_BIND_ASSERT(num_tokens_per_rank.size(0) == num_ranks); - - auto num_tokens = static_cast(new_x.size(0)); - auto hidden = static_cast(new_x.size(1)); - auto num_experts = static_cast(num_tokens_per_expert.size(0)); - auto num_local_experts = static_cast(num_experts / num_ranks); - - // Top-k checks - int num_topk = 0; - num_topk = static_cast(topk_idx.size(1)); - TORCH_BIND_ASSERT(num_experts > 0); - TORCH_BIND_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); - TORCH_BIND_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); - TORCH_BIND_ASSERT(num_tokens == topk_idx.size(0)); - TORCH_BIND_ASSERT(num_topk == topk_weights.size(1)); - TORCH_BIND_ASSERT(topk_weights.scalar_type() == at::kFloat); - - int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens) - - auto send_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device())); - int64_t send_count = send_per_group * num_local_experts * num_ranks; - - auto send_data_offset = at::empty({num_experts}, at::dtype(at::kInt).device(x.device())); - at::Tensor recv_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device())); - - int64_t local_rank_size = num_ranks; - int64_t local_rank_id = rank % local_rank_size; - - EXEC_NPU_CMD(aclnnNotifyDispatch, - send_data, - num_tokens_per_expert, - send_count, - num_tokens, - group_ep_ptr, // commGroup - num_ranks, // rankSize - rank, // rankId - local_rank_size, - local_rank_id, - send_data_offset, - recv_data); - - auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - std::vector local_expert_acc(num_experts, 0); - auto send_token_idx_cpu = at::empty({num_tokens, num_topk}, options_cpu); - auto send_token_idx_ptr = send_token_idx_cpu.data_ptr(); - - auto topk_idx_cpu = topk_idx.to(at::kCPU); - auto topk_idx_ptr = topk_idx_cpu.data_ptr(); - for (int i = 0; i < num_tokens; ++i) { - for (int j = 0; j < num_topk; ++j) { - int64_t expert_idx = topk_idx_ptr[i * num_topk + j]; - if (expert_idx >= 0) { - int32_t cnt = local_expert_acc[expert_idx]; - send_token_idx_ptr[i * num_topk + j] = cnt; - local_expert_acc[expert_idx]++; - } - } - } - - TORCH_BIND_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous()); - TORCH_BIND_ASSERT(recv_data.size(0) % num_experts == 0); - at::Tensor recv_offset_cpu = at::empty({num_experts}, options_cpu); - at::Tensor recv_count_cpu = at::empty({num_experts}, options_cpu); - auto recv_data_cpu = recv_data.to(at::kCPU); - auto recv_data_ptr = recv_data_cpu.data_ptr(); - auto recv_count_ptr = recv_count_cpu.data_ptr(); - auto recv_offset_ptr = recv_offset_cpu.data_ptr(); - int64_t total_recv_tokens = 0; - int64_t num_max_dispatch_tokens_per_rank = 0; - std::vector num_recv_tokens_per_expert_list; - - for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) { - int64_t local_expert_recv_tokens = 0; - for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) { - int64_t index = local_e * num_ranks + src_rank; - int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e); - - int recv_cnt = recv_data_ptr[pair_idx]; // count from this src_rank for - // this global_expert - int recv_off = recv_data_ptr[pair_idx + 1]; // offset in that src_rank's window - int64_t send_num_tokens = recv_data_ptr[pair_idx + 2]; // all bs from rank - - total_recv_tokens += recv_cnt; - recv_count_ptr[index] = total_recv_tokens; - recv_offset_ptr[index] = recv_off; - num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens); - - local_expert_recv_tokens += recv_cnt; - } - num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens); - } - auto option = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); - at::Tensor num_recv_tokens_per_expert = torch::from_blob( - num_recv_tokens_per_expert_list.data(), {static_cast(num_recv_tokens_per_expert_list.size())}, option) - .clone(); - - at::Tensor expert_ids = topk_idx.to(at::kInt); - int64_t tp_size = 1; - int64_t tp_rank = 0; - int64_t quant_mode = 0; - int64_t global_bs = static_cast( - std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast(num_worst_tokens))); - - auto send_token_idx = send_token_idx_cpu.to(x.device()); - auto recv_offset = recv_offset_cpu.to(x.device()); - auto recv_count = recv_count_cpu.to(x.device()); - - int total_cnt = total_recv_tokens; - if (total_cnt == 0) { - total_cnt = 1; - } - auto expandx_out = at::empty({total_cnt, hidden}, x.options()); - auto dynamic_scales_out = at::empty({total_cnt}, at::dtype(at::kFloat).device(x.device())); - auto expand_idx_out = at::empty({total_cnt * 3}, at::dtype(at::kInt).device(x.device())); - - EXEC_NPU_CMD(aclnnMoeDispatchNormal, - new_x, - expert_ids, - send_data_offset, - send_token_idx, - recv_offset, - recv_count, - group_ep_ptr, // commGroup - num_ranks, // rankSize - rank, // rankId - group_ep_ptr, - tp_size, - tp_rank, - num_experts, - quant_mode, - global_bs, - expandx_out, - dynamic_scales_out, - expand_idx_out); - - // Return values - return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert}; -} - -at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, - const at::Tensor& src_idx, const at::Tensor& send_head, c10::string_view groupEp, - int64_t rank, int64_t num_ranks) { - std::vector group_ep_chrs(groupEp.begin(), groupEp.end()); - group_ep_chrs.push_back('\0'); - char* group_ep_ptr = &group_ep_chrs[0]; - - TORCH_BIND_ASSERT(x.dim() == 2 and x.is_contiguous()); - at::Tensor recv_x = x; - - at::Tensor topk_idx_p = topk_idx; - - auto topk_idx_int32 = topk_idx_p.to(at::kInt); - at::Tensor expand_ids = topk_idx_int32; - at::Tensor token_src_info = src_idx; - at::Tensor ep_send_counts = send_head; - auto device = x.device(); - - const int num_tokens = topk_idx_p.size(0); - const int num_topk = topk_idx_p.size(1); - - int64_t hidden = static_cast(recv_x.size(1)); - at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device)); - int64_t tp_world_size = 1; - int64_t tp_rankId = 0; - int64_t moe_expert_number = send_head.size(0); - int64_t global_bs = topk_idx_p.size(0) * num_ranks; - - // Combine data - auto combined_x = torch::empty({topk_weights.size(0), hidden}, x.options()); - - EXEC_NPU_CMD(aclnnMoeCombineNormal, - recv_x, - token_src_info, - ep_send_counts, - topk_weights, - tp_send_counts, - group_ep_ptr, - num_ranks, - rank, - group_ep_ptr, - tp_world_size, - tp_rankId, - moe_expert_number, - global_bs, - combined_x); - - return combined_x; -} - } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -1162,25 +921,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int max_output_size, Tensor! out) -> Tensor" ); ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine); - - ops.def("get_dispatch_layout(Tensor topk_idx, int num_experts, int " - "num_ranks) -> (Tensor num_tokens_per_rank, Tensor " - "num_tokens_per_expert, Tensor is_token_in_rank_bool)"); - ops.impl("get_dispatch_layout", torch::kPrivateUse1, - &vllm_ascend::get_dispatch_layout); - - ops.def( - "dispatch_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, " - "Tensor num_tokens_per_rank, Tensor is_token_in_rank, Tensor " - "num_tokens_per_expert, int num_worst_tokens, str groupEp, int rank, " - "int num_ranks) -> (Tensor expandx_out, Tensor expand_idx_out, Tensor " - "recv_count, Tensor num_recv_tokens_per_expert)"); - ops.impl("dispatch_prefill", torch::kPrivateUse1, - &vllm_ascend::dispatch_prefill); - - ops.def("combine_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, " - "Tensor src_idx, Tensor send_head, str grouEp, int rank, int " - "num_ranks) -> Tensor"); - ops.impl("combine_prefill", torch::kPrivateUse1, - &vllm_ascend::combine_prefill); } diff --git a/csrc/utils.h b/csrc/utils.h index a692b87f..74481e1b 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -28,28 +28,4 @@ return PyModule_Create(&module); \ } -class TrochBindException : public std::exception -{ -private: - std::string message = {}; -public: - explicit TrochBindException(const char *name, const char *file, const int line, const std::string &error) - { - message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + - " error message or error code is '" + error + "'"; - } - - const char *what() const noexcept override - { - return message.c_str(); - } -}; - -#define TORCH_BIND_ASSERT(cond) \ - ; \ - do { \ - if (not(cond)) { \ - throw TrochBindException("Assertion", __FILE__, __LINE__, #cond); \ - } \ - } while (0) diff --git a/csrc/utils/inc/kernel/comm_args.h b/csrc/utils/inc/kernel/comm_args.h deleted file mode 100644 index 3aadb840..00000000 --- a/csrc/utils/inc/kernel/comm_args.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef COMM_ARGS_H -#define COMM_ARGS_H -#include - -#define FORCE_INLINE_AICORE __attribute__((always_inline)) inline __aicore__ -#include "kernel_operator.h" - -namespace Moe { -constexpr int CAM_MAX_RANK_SIZE = 384; // Maximum number of NPU cards supported by the communication library - -constexpr int64_t IPC_BUFF_MAX_SIZE = 100 * 1024 * 1024; -constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // First 2MB as flag, then 100MB as data storage -constexpr int64_t PING_PONG_SIZE = 2; -constexpr int64_t UB_SINGLE_DMA_SIZE_MAX = 190 * 1024; -constexpr int64_t SMALL_DATA_SIZE = 1 * 1024 * 1024; -constexpr int64_t UB_SINGLE_PING_PONG_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX / 2; -constexpr int UB_ALIGN_SIZE = 32; -constexpr int64_t MAGIC_ALIGN_COUNT = UB_ALIGN_SIZE / sizeof(int32_t); - -constexpr uint8_t COMM_NUM = 2; // Size of communication domain -constexpr uint8_t COMM_EP_IDX = 0; -constexpr uint8_t COMM_TP_IDX = 1; - -constexpr int DFX_COUNT = 50; -constexpr int64_t WAIT_SUCCESS = 112233445566; -constexpr int64_t IPC_CHUNK_FLAG = 0; // Start offset for send recv, chunk flag region -constexpr int64_t MAX_WAIT_ROUND_UNIT = 10 * 1000 * 1000; // Threshold for waiting to get Flag under normal conditions within the same SIO - -constexpr static int32_t UB_HEAD_OFFSET = 96; -constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + UB_ALIGN_SIZE; -constexpr static int64_t UB_FLAG_SIZE = 2 * 1024; -constexpr static int64_t MAX_CORE_NUM = 48; -constexpr static uint64_t STATE_WIN_OFFSET = 900 * 1024; -constexpr static int64_t COMPARE_ALIGN_SIZE = 256; - -constexpr static int64_t UB_SINGLE_TOTAL_SIZE_MAX = 192 * 1024; -constexpr static int64_t START_OFFSET_FOR_SHARE = 512; - -enum Op : int { - COPYONLY = -1, - ADD = 0, - MUL = 1, - MAX = 2, - MIN = 3 -}; - -struct CommArgs { - int rank = 0; // attr rank_id, global rank - int localRank = -1; - int rankSize = 0; // global rank size - int localRankSize = -1; // This parameter refers to the number of cards interconnected in fullmesh - uint32_t extraFlag = 0; // 32 bit map, the specific meaning of each bit is above in this file - int testFlag = 0; - GM_ADDR peerMems[CAM_MAX_RANK_SIZE] = {}; // Buffer obtained from initialization, all allreduce is the same parameter - /** - * @param sendCountMatrix One-dimensional array with a size of rankSize*rankSize - * eg: The value of sendCountMatrix[1] corresponds to the [0][1] of the two-dimensional array, indicating the number of data that card 0 needs to send to card 1 - */ - int64_t sendCountMatrix[CAM_MAX_RANK_SIZE * CAM_MAX_RANK_SIZE] = {}; // for all2allvc - int64_t sendCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv - int64_t sdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv - int64_t recvCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv - int64_t rdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv - int64_t batchSize; - int64_t hiddenSize; - int64_t topk; - int64_t sharedExpertRankNum; - int64_t expertNumPerRank; - int64_t dfx[DFX_COUNT] = {}; -}; -} -#endif // COMM_ARGS_H diff --git a/csrc/utils/inc/kernel/data_copy.h b/csrc/utils/inc/kernel/data_copy.h deleted file mode 100644 index d9490e1c..00000000 --- a/csrc/utils/inc/kernel/data_copy.h +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef CAM_DATACOPY_GM2GM_H -#define CAM_DATACOPY_GM2GM_H -#include -#include "comm_args.h" - -using namespace AscendC; -using namespace Moe; - -template -FORCE_INLINE_AICORE void SetAtomicOpType(int op) -{ - switch (op) { - case ADD: - AscendC::SetAtomicAdd(); - break; - case MUL: - // Ignore setting the atomic register when performing mul - break; - case MAX: - AscendC::SetAtomicMax(); - break; - case MIN: - AscendC::SetAtomicMin(); - break; - default: - AscendC::SetAtomicNone(); - } -} - -template -FORCE_INLINE_AICORE void CpUB2GM(__gm__ T *gmAddr, __ubuf__ T *ubAddr, uint32_t size) -{ - LocalTensor ubTensor; - GlobalTensor gmTensor; - DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - ubTensor.address_.logicPos = static_cast(TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); - DataCopyPad(gmTensor, ubTensor, dataCopyParams); -} - -template -FORCE_INLINE_AICORE void CpGM2UB(__ubuf__ T *ubAddr, __gm__ T *gmAddr, uint32_t size) -{ - LocalTensor ubTensor; - GlobalTensor gmTensor; - DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); - ubTensor.address_.logicPos = static_cast(TPosition::VECIN); - ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); - gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); - DataCopyPadExtParams padParams; - DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); -} - -template -FORCE_INLINE_AICORE void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, const uint32_t calCount) -{ - LocalTensor srcTensor; - LocalTensor dstTensor; - TBuffAddr srcAddr, dstAddr; - srcAddr.bufferAddr = reinterpret_cast(src); - dstAddr.bufferAddr = reinterpret_cast(dst); - srcTensor.SetAddr(srcAddr); - dstTensor.SetAddr(dstAddr); - DataCopy(dstTensor, srcTensor, calCount); -} - -#endif // CAM_DATACOPY_GM2GM_H \ No newline at end of file diff --git a/csrc/utils/inc/kernel/moe_distribute_base.h b/csrc/utils/inc/kernel/moe_distribute_base.h deleted file mode 100644 index 607a8799..00000000 --- a/csrc/utils/inc/kernel/moe_distribute_base.h +++ /dev/null @@ -1,199 +0,0 @@ -/*! - * \file moe_distribute_base.h - * \brief - */ - -#ifndef MOE_DISTRIBUTE_BASE_H -#define MOE_DISTRIBUTE_BASE_H - -/* system tick: 50MHz */ -#define CAL_US(tick) (((tick) * 2) / 100) - -/* performance macro */ -// #define USE_256_TO_1__ // Enable 256 to 1 -#ifdef USE_256_TO_1__ - #pragma message("use 256 to 1") -#else // 256 to 1 is only used as baseline, not combined with other optimization points - #define USE_FOR_OPT__ // Enable loop optimization loop optimization - #define DISPATCH_USE_WRITE_SHUFFLE__ // Dispatch uses write shuffle - #define USE_TOKEN_COUNT_SPLIT__ // Enable separation of token and count flags token and count flags - #define USE_ONE_CORE_WAIT__ // Enable single core wait - - #ifdef USE_ONE_CORE_WAIT__ - #pragma message("use one core wait") - // Enable single core cumsum calculation - // #define USE_ONE_CORE_GETCUMSUM__ - #endif - #ifdef USE_FOR_OPT__ - #pragma message("use for optimization") - #define FOR_OPT_MAX_BS__ 64 - #define FOR_OPT_MAX_MOE_RANK__ 256 - #endif - // #define COMBINE_USE_DYNAMIC_QUANT // Combine quantization is disabled by default - #define OPT_RANK_OFFSET 512 - #define USE_WRITE_SHUFFLE -#endif - -constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64; -constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19; -constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2; -constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024; - -struct HcclSignalInfo { - uint64_t resId; // EventId when representing event, notifyId when representing notify - uint64_t addr; - uint32_t devId; - uint32_t tsId; - uint32_t rankId; - uint32_t flag; -}; - -struct ListCommon { - uint64_t nextHost; - uint64_t preHost; - uint64_t nextDevice; - uint64_t preDevice; -}; - -struct HcclStreamInfo { - int32_t streamIds; - uint32_t sqIds; - uint32_t cqIds; // Record physical cqId - uint32_t logicCqids; // Record logical cqId -}; - -struct LocalResInfoV2 { - uint32_t streamNum; - uint32_t signalNum; - HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM]; - HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM]; - HcclStreamInfo mainStreamInfo; - HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; // Collective communication AICPU expanded resources - ListCommon nextTagRes; // HccltagLocalResV2 -}; - -enum class rtFloatOverflowMode_t { - RT_OVERFLOW_MODE_SATURATION = 0, - RT_OVERFLOW_MODE_INFNAN, - RT_OVERFLOW_MODE_UNDEF, -}; - -struct AlgoTopoInfo { - uint32_t userRank; // Communication domain RankID - uint32_t userRankSize; // Number of Ranks in communication domain - int32_t deviceLogicId; - bool isSingleMeshAggregation; - uint32_t deviceNumPerAggregation; // Number of Devices in each Module - uint32_t superPodNum; // Total number of super nodes in cluster - uint32_t devicePhyId; - uint32_t topoType; // TopoType - uint32_t deviceType; - uint32_t serverNum; - uint32_t meshAggregationRankSize; - uint32_t multiModuleDiffDeviceNumMode; - uint32_t multiSuperPodDiffServerNumMode; - uint32_t realUserRank; - bool isDiffDeviceModule; - bool isDiffDeviceType; - uint32_t gcdDeviceNumPerAggregation; - uint32_t moduleNum; - uint32_t isUsedRdmaRankPairNum; - uint64_t isUsedRdmaRankPair; - uint32_t pairLinkCounterNum; - uint64_t pairLinkCounter; - uint32_t nicNum; - uint64_t nicList; // Pointer to niclist array - uint64_t complanRankLength; // Bytes occupied by complanRank - uint64_t complanRank; // Pointer - uint64_t bridgeRankNum; // Number of bridgeRank entries - uint64_t bridgeRank; // Pointer - uint64_t serverAndsuperPodRankLength; // Bytes occupied by serverAndsuperPodRank - uint64_t serverAndsuperPodRank; // Pointer -}; - -struct HcclOpConfig { - uint8_t deterministic; // Deterministic computation switch - uint8_t retryEnable; // Whether to retry execution - uint8_t highPerfEnable; - uint8_t padding[5]; // Size needs 64-byte alignment, reduce padding when adding parameters in future - uint8_t linkTimeOut[8]; // Send timeout duration - uint64_t notifyWaitTime; // Timeout duration, same as HCCL_EXEC_TIMEOUTas HCCL_EXEC_TIMEOUT - uint32_t retryHoldTime; - uint32_t retryIntervalTime; - bool interHccsDisable = false; // Enable RDMA switch - rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF; - uint32_t multiQpThreshold = 512; // Minimum data amount threshold for each QP in multi-QP mode -}; - -struct HcclMC2WorkSpace { - uint64_t workSpace; - uint64_t workSpaceSize; -}; - -struct RemoteResPtr { - uint64_t nextHostPtr; - uint64_t nextDevicePtr; -}; - -struct HDCommunicateParams { - uint64_t hostAddr { 0 }; - uint64_t deviceAddr { 0 }; - uint64_t readCacheAddr { 0 }; - uint32_t devMemSize{ 0 }; - uint32_t buffLen{ 0 }; - uint32_t flag{ 0 }; -}; - -struct HcclRankRelationResV2 { - uint32_t remoteUsrRankId; - uint32_t remoteWorldRank; - uint64_t windowsIn; - uint64_t windowsOut; - uint64_t windowsExp; - ListCommon nextTagRes; -}; - -struct HcclOpResParam { - // Local resources - HcclMC2WorkSpace mc2WorkSpace; - uint32_t localUsrRankId; // usrrankid - uint32_t rankSize; // Total number of ranks in communication domain - uint64_t winSize; // Size of each window, may be 0 for static graphs, may be non-zero if dynamic graphs exist in communication domain - uint64_t localWindowsIn; // All F means invalid value - uint64_t localWindowsOut; // All F means invalid value - char hcomId[128]; - // AICore identifies remote window - uint64_t winExpSize; - uint64_t localWindowsExp; - uint32_t rWinStart; // Start position for HcclRankRelationRes - uint32_t rWinOffset; // Size of HcclRemoteRes - uint64_t version; - LocalResInfoV2 localRes; - AlgoTopoInfo topoInfo; - - // External configuration parameters - HcclOpConfig config; - uint64_t hostStateInfo; - uint64_t aicpuStateInfo; - uint64_t lockAddr; - uint32_t rsv[16]; - uint32_t notifysize; // Used in RDMA scenarios, 4B for 910B/910_93, 8B for other chips - uint32_t remoteResNum; // Valid remoteResNum - RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; // Array pointer, points to HcclRankRelationResV2, index is remoteUserRankId - - // communicate retry - HDCommunicateParams kfcControlTransferH2DParams; - HDCommunicateParams kfcStatusTransferD2HParams; - uint64_t tinyMem; // for all2all - uint64_t tinyMemSize; - // Used in zero-copy scenarios - uint64_t zeroCopyHeadPtr; - uint64_t zeroCopyTailPtr; - uint64_t zeroCopyRingBuffer; - uint64_t zeroCopyIpcPtrs[16]; // Save input/output memory addresses of each peer during collective communication - uint32_t zeroCopyDevicePhyId[16]; // Save physical card ID corresponding to each rank - - bool utraceStatusFlag; -}; - -#endif // MOE_DISTRIBUTE_BASE_H \ No newline at end of file diff --git a/csrc/utils/inc/kernel/sync_collectives.h b/csrc/utils/inc/kernel/sync_collectives.h deleted file mode 100644 index 9653e21a..00000000 --- a/csrc/utils/inc/kernel/sync_collectives.h +++ /dev/null @@ -1,426 +0,0 @@ -#ifndef SYNC_COLLECTIVES_H -#define SYNC_COLLECTIVES_H - -#include "comm_args.h" - -using namespace AscendC; -using namespace Moe; - -// Synchronization flag occupies length -constexpr int64_t FLAG_UNIT_INT_NUM = 4; -// Memory size occupied by each synchronization unit (Bytes) -constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t); -// High-order offset when using magic as a comparison value -constexpr int64_t MAGIC_OFFSET = 32; -constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1); - -class SyncCollectives { -public: - __aicore__ inline SyncCollectives() {} - - __aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, TBuf &tBuf) - { - this->rank = rank; - this->rankSize = rankSize; - this->shareAddrs = shareAddrs; - this->blockIdx = GetBlockIdx(); - this->blockNum = GetBlockNum(); - // Length of a single indicator segment - segmentCount = GetBlockNum() * FLAG_UNIT_INT_NUM; - // Initialize the intra-card/inter-card synchronization address corresponding to the current core. - localSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]); - basicSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + GetBlockIdx() * FLAG_UNIT_INT_NUM; - blockOuterSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + segmentCount + GetBlockIdx() * FLAG_UNIT_INT_NUM; - this->tBuf = tBuf; - } - - __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID) - { - int64_t v = MergeMagicWithValue(magic, value); - SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v); - } - - /** - * @brief Set the flag for the specified eventID of the designated card, with the value being a combination of magic and value. - * @param magic The operator batch, which will be combined into the high 32 bits of the flag value to be set. - * @param value The specific value to be set, which will be the low 32 bits of the flag value to be set. - * @param eventID Physically, it is an offset from the shared memory base address (requires scaling, not an absolute value). - * @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs structure, not a global or local id. - * (Local is not applicable in the 91093 scenario, and global is not applicable in the 910B multi-machine scenario.) - */ - __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) - { - int64_t v = MergeMagicWithValue(magic, value); - SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v); - } - - __aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId) - { - return (blockMultiplier * blockNum) + targetCoreId; - } - - /** - * @brief Wait for the flag of the specified eventID on the specified card to become a value - * composed of the combination of magic and value. - * @param magic The operator batch, which will be combined into the high 32 bits of the flag - * value to be wait. - * @param value The specific value to be wait, which will be the low 32 bits of the flag - * value to be wait. - * @param eventID Physically, it is an offset from the shared memory base address (requires - * scaling, not an absolute value). - * @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs - * structure, not a global or local id. (Local is not applicable in the 91093 - * scenario, and global is not applicable in the 910B multi-machine scenario.) - */ - __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) - { - int64_t v = MergeMagicWithValue(magic, value); - WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); - } - - __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID) - { - int64_t v = MergeMagicWithValue(magic, value); - WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); - } - - /** - * @brief Wait for the flags starting from the specified eventID on the specified card to become - * a value composed of the combination of magic and value.
- * Note: [eventID, eventID + flagNum) - */ - __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum) - { - int64_t v = MergeMagicWithValue(magic, value); - WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v); - } - - // Set inner-card synchronization flag (memory A) - __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID) - { - int64_t value = MergeMagicWithValue(magic, eventID); - SetFlag(basicSyncAddr, value); - } - - __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) - { - int64_t value = MergeMagicWithValue(magic, eventID); - SetFlag((__gm__ int64_t*)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value); - } - - // Wait for a single inner-card synchronization flag (memory A) - __aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) - { - int64_t value = MergeMagicWithValue(magic, eventID); - WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value); - } - - // Wait for all inner-card synchronization flags within the entire rank (memory A) - __aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) - { - int64_t value = MergeMagicWithValue(magic, eventID); - WaitOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); - } - - // Check all inner-card synchronization flags within the entire rank (memory A) - __aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) - { - int64_t value = MergeMagicWithValue(magic, eventID); - return CheckOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); - } - - // Set inter-card synchronization flag (memory B) - __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID) - { - int64_t value = MergeMagicWithValue(magic, eventID); - SetFlag(blockOuterSyncAddr, value); - } - - __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) - { - __gm__ int64_t* flagAddr = GetOuterFlagAddr(setRank, setBlock); - int64_t value = MergeMagicWithValue(magic, eventID); - SetFlag(flagAddr, value); - } - - // Wait for a single inter-card synchronization flag (memory B) - __aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) - { - int64_t value = MergeMagicWithValue(magic, eventID); - __gm__ int64_t* flagAddr = GetOuterFlagAddr(waitRank, waitBlock); - WaitOneRankPartFlag(flagAddr, 1, value); - } - - // Wait for all inter-card synchronization flags within the entire rank (memory B) - __aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t eventID, int64_t rank) - { - int64_t value = MergeMagicWithValue(magic, eventID); - __gm__ int64_t* flagAddr; - flagAddr = GetOuterFlagAddr(rank, 0); - WaitOneRankPartFlag(flagAddr, blockNum, value); - } - - // Wait for flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B) - __aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum) - { - int64_t value = MergeMagicWithValue(magic, eventID); - __gm__ int64_t* flagAddr; - int waitRank; - for (auto r = 0; r < rankSize; ++r) { - waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores - flagAddr = GetOuterFlagAddr(waitRank, startBlock); - WaitOneRankPartFlag(flagAddr, flagNum, value); - } - } - - // Check flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B) - __aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, - int64_t flagNum) - { - int64_t value = MergeMagicWithValue(magic, eventID); - __gm__ int64_t* flagAddr; - int waitRank; - for (auto r = 0; r < rankSize; ++r) { - waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores - flagAddr = GetOuterFlagAddr(waitRank, startBlock); - if (!CheckOneRankPartFlag(flagAddr, flagNum, value)) { - return false; - } - } - return true; - } - - // Wait for all inter-card synchronization flags for all ranks, full rank synchronization (memory B) - __aicore__ inline void WaitAllRankOuterFlag(int32_t magic, int32_t eventID) - { - WaitAllRankPartOuterFlag(magic, eventID, 0, blockNum); - } - - // Check all inter-card synchronization flags for all ranks, full rank synchronization (memory B) - __aicore__ inline bool CheckAllRankOuterFlag(int32_t magic, int32_t eventID) - { - return CheckAllRankPartOuterFlag(magic, eventID, 0, blockNum); - } - - // Low-level interface, set synchronization flag - __aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue) - { - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - GlobalTensor globalSet; - globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM); - LocalTensor localSet = tBuf.GetWithOffset(1, 0); - localSet.SetValue(0, setValue); - - // Copy global synchronization flag to local - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for SetValue to complete - DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for UB->GM to complete - } - - // Low-level interface, wait for synchronization flag - __aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue) - { - WaitOneRankPartFlag(waitAddr, 1, waitValue); - } - - // Read a flag, return an immediate number - __aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr) - { - GlobalTensor globalWait; - globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); - LocalTensor localWait = tBuf.GetWithOffset(1, 0); - // Copy global to local - DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB - - int64_t res = localWait.GetValue(0); - return res; - } - - // Get multiple consecutive synchronization flags within a single card - __aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, - int64_t startBlock, int64_t flagNum) - { - int64_t value = MergeMagicWithValue(magic, eventID); - __gm__ int64_t* flagAddr; - flagAddr = GetOuterFlagAddr(waitRank, startBlock); - WaitOneRankPartFlag(flagAddr, flagNum, value); - } - - // Get synchronization flag within a single card (memory A) - __aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock) - { - return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM); - } - - __aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock) - { - return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM); - } - - // In the rank Chunk Flag area, return success if the destRank chunk Flag value is 0, otherwise fail - __aicore__ inline int64_t GetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout) - { - int64_t value = MergeMagicWithValue(magic, 0); - int64_t status = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) + - IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value, timeout); - return status; - } - - // Set the destRank chunk Flag value in the rank Chunk Flag area to value - __aicore__ inline void SetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t eventId) - { - int64_t value = MergeMagicWithValue(magic, eventId); - SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value); - } - - __aicore__ inline int64_t GetChunkRecvLen(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout) - { - int64_t len = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG + - destRank * FLAG_UNIT_INT_NUM, 0, timeout, true, magic); - return len; - } - -private: - __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value) - { - // Merge magic as the high bits and eventID as the low bits into a value for comparison - return (static_cast(static_cast(magic)) << MAGIC_OFFSET) | static_cast(value); - } - - __aicore__ inline __gm__ int64_t* GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock) - { - return (__gm__ int64_t*)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM; - } - - __aicore__ inline __gm__ int64_t* GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock) - { - return (__gm__ int64_t*)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM; - } - - // Wait for a part of synchronization flags within a rank - __aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue) - { - GlobalTensor globalWait; - globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); - LocalTensor localWait = tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); - bool isSync = true; - int64_t checkedFlagNum = 0; - do { - // Copy global synchronization flags to local - DataCopy(localWait, globalWait[checkedFlagNum * FLAG_UNIT_INT_NUM], - (flagNum - checkedFlagNum) * FLAG_UNIT_INT_NUM); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB - - // Check if the synchronization flags are equal to checkValue - isSync = true; - int64_t remainToCheck = flagNum - checkedFlagNum; - for (auto i = 0; i < remainToCheck; ++i) { - // Continue waiting if any core has not reached the checkValue phase - int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); - if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { - isSync = false; - checkedFlagNum += i; - break; - } - } - } while (!isSync); - } - - // Wait for all synchronization flags within a rank - __aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue) - { - WaitOneRankPartFlag(waitAddr, blockNum, checkValue); - } - - // Check partial synchronization flags within a rank, copy only once - __aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue) - { - GlobalTensor globalWait; - globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); - LocalTensor localWait = tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); - // Copy global synchronization flags to local - DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB - // Check if the synchronization flags are equal to checkValue - bool isSync = true; - for (auto i = 0; i < flagNum; ++i) { - // Continue waiting if any core has not reached the checkValue phase - int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); - if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { - isSync = false; - break; - } - } - return isSync; - } - - __aicore__ inline int64_t GetChunkFlagValue(__gm__ int64_t* waitAddr, int64_t checkValue, int64_t timeout, - bool checkNonZero = false, int64_t magic = 0) - { - GlobalTensor globalWait; - globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); - LocalTensor localWait = tBuf.GetWithOffset(FLAG_UNIT_INT_NUM, 0); - bool isSync = true; - - int64_t waitTimes = 0; - int64_t v = 0; - - do { - // Copy global sync flag to local - DataCopy(localWait, globalWait[0], FLAG_UNIT_INT_NUM); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB - - isSync = true; - v = localWait.GetValue(0); - if (checkNonZero) { - // Non-zero check mode - if (((v & MAGIC_MASK) == (static_cast(magic) << MAGIC_OFFSET)) && (v & 0xFFFFFFFF)) { - return v & 0xFFFFFFFF; // Return lower 32 bits when non-zero - } - } else { - // Exact value check mode - if (v == checkValue) { - return WAIT_SUCCESS; - } - } - - isSync = false; - waitTimes++; - - if (timeout > INT64_MAX / MAX_WAIT_ROUND_UNIT || waitTimes >= (timeout * MAX_WAIT_ROUND_UNIT)) { - isSync = true; - return v; // Return the read flag value - } - } while (!isSync); - - return checkNonZero ? 0 : v; - } - - // Check all sync flags within a rank, copy only once - __aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue) - { - return CheckOneRankPartFlag(waitAddr, blockNum, checkValue); - } - int rank; - int rankSize; - int blockIdx; - int blockNum; - GM_ADDR *shareAddrs; - int64_t segmentCount; // Length of a single sync flag segment (count in int64_t) - __gm__ int64_t* localSyncAddr; - __gm__ int64_t* basicSyncAddr; // Intra-card sync flag address for the current block - __gm__ int64_t* blockOuterSyncAddr; // Inter-card sync flag address for the current block - TBuf tBuf; -}; - -#endif // SYNC_COLLECTIVES_H \ No newline at end of file