[Kernel] add custom moe ops for prefill (#4194)
### What this PR does / why we need it?
1.Add the implementation of normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch,and DispatchLayout.
- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Used to calculate information related to the device
memory layout for the dispatch stage.
2.Provide PyTorch interfaces for normal operators—get_dispatch_layout,
dispatch_prefill, and combine_prefill—to be used for MoE communication
during the prefill stage in vLLM.
- get_dispatch_layout: Calculates information related to the device
memory layout for the dispatch operator, and is called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has already been validated using the local Qwen model.
Test cases will be added after support for multi-NPU use cases in the CI
pipeline is finalized.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
This commit is contained in:
@@ -45,7 +45,19 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
|
||||
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
|
||||
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine;dispatch_gmm_combine_decode;"
|
||||
|
||||
CUSTOM_OPS_ARRAY=(
|
||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
||||
"lightning_indexer"
|
||||
"sparse_flash_attention"
|
||||
"dispatch_ffn_combine"
|
||||
"dispatch_gmm_combine_decode"
|
||||
"moe_combine_normal"
|
||||
"moe_dispatch_normal"
|
||||
"dispatch_layout"
|
||||
"notify_dispatch"
|
||||
)
|
||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||
SOC_ARG="ascend910_93"
|
||||
else
|
||||
# others
|
||||
@@ -58,7 +70,7 @@ fi
|
||||
cd csrc
|
||||
rm -rf build output
|
||||
echo "building custom ops $CUSTOM_OPS for $SOC_VERSION"
|
||||
bash build.sh -n $CUSTOM_OPS -c $SOC_ARG
|
||||
bash build.sh -n "$CUSTOM_OPS" -c "$SOC_ARG"
|
||||
|
||||
# install custom ops to vllm_ascend/_cann_ops_custom
|
||||
./output/CANN-custom_ops*.run --install-path=$ROOT_DIR/vllm_ascend/_cann_ops_custom
|
||||
|
||||
49
csrc/dispatch_layout/op_host/CMakeLists.txt
Normal file
49
csrc/dispatch_layout/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME DispatchLayout
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
dispatch_layout.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_dispatch_layout.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_dispatch_layout.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_dispatch_layout.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
dispatch_layout_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_layout.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
64
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp
Normal file
64
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_dispatch_layout.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern aclnnStatus aclnnInnerDispatchLayoutGetWorkspaceSize(
|
||||
const aclTensor *topkIdx,
|
||||
int64_t numTokens,
|
||||
int64_t numRanks,
|
||||
int64_t numExperts,
|
||||
int64_t numTopk,
|
||||
const aclTensor *numTokensPerRank,
|
||||
const aclTensor *numTokensPerExpert,
|
||||
const aclTensor *isTokenInRank,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerDispatchLayout(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
|
||||
const aclTensor *topkIdx,
|
||||
int64_t numTokens,
|
||||
int64_t numRanks,
|
||||
int64_t numExperts,
|
||||
int64_t numTopk,
|
||||
const aclTensor *numTokensPerRank,
|
||||
const aclTensor *numTokensPerExpert,
|
||||
const aclTensor *isTokenInRank,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, numTokensPerRank,
|
||||
numTokensPerExpert, isTokenInRank, workspaceSize, executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnDispatchLayout(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerDispatchLayout(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
50
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h
Normal file
50
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h
Normal file
@@ -0,0 +1,50 @@
|
||||
#ifndef ACLNN_DISPATCH_LAYOUT_H_
|
||||
#define ACLNN_DISPATCH_LAYOUT_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* funtion: aclnnDispatchLayoutGetWorkspaceSize
|
||||
* topkIdx : required
|
||||
* numTokens : required
|
||||
* numRanks : required
|
||||
* numExperts : required
|
||||
* numTopk : required
|
||||
* numTokensPerRank : required
|
||||
* numTokensPerExpert : required
|
||||
* isTokenInRank : required
|
||||
* workspaceSize : size of workspace(output).
|
||||
* executor : executor context(output).
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
|
||||
const aclTensor *topkIdx,
|
||||
int64_t numTokens,
|
||||
int64_t numRanks,
|
||||
int64_t numExperts,
|
||||
int64_t numTopk,
|
||||
const aclTensor *numTokensPerRank,
|
||||
const aclTensor *numTokensPerExpert,
|
||||
const aclTensor *isTokenInRank,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
/* funtion: aclnnDispatchLayout
|
||||
* workspace : workspace memory addr(input).
|
||||
* workspaceSize : size of workspace(input).
|
||||
* executor : executor context(input).
|
||||
* stream : acl stream.
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayout(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
51
csrc/dispatch_layout/op_host/dispatch_layout.cpp
Normal file
51
csrc/dispatch_layout/op_host/dispatch_layout.cpp
Normal file
@@ -0,0 +1,51 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class DispatchLayout : public OpDef {
|
||||
public:
|
||||
explicit DispatchLayout(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("topkIdx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
|
||||
this->Attr("num_tokens").Int();
|
||||
this->Attr("num_ranks").Int();
|
||||
this->Attr("num_experts").Int();
|
||||
this->Attr("num_topk").Int();
|
||||
|
||||
this->Output("numTokensPerRank")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("numTokensPerExpert")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("isTokenInRank")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_true")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(DispatchLayout);
|
||||
} // namespace ops
|
||||
211
csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp
Normal file
211
csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp
Normal file
@@ -0,0 +1,211 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/dispatch_layout_tiling.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
#include "experiment/platform/platform/platform_infos_def.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace {
|
||||
constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0;
|
||||
|
||||
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0;
|
||||
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1;
|
||||
constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2;
|
||||
|
||||
constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0;
|
||||
constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1;
|
||||
constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2;
|
||||
constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3;
|
||||
const int64_t MAX_COMM_WORLD_SIZE = 384;
|
||||
const int64_t MAX_MOE_EXPERTS_NUM = 384;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;
|
||||
|
||||
constexpr uint32_t TWO_DIMS = 2;
|
||||
constexpr uint32_t K_MAX = 16;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "numToken is %u.", tilingData.dispatchLayoutInfo.numTokens);
|
||||
OPS_LOG_D(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks);
|
||||
OPS_LOG_D(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts);
|
||||
OPS_LOG_D(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize);
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
|
||||
DispatchLayoutTilingData &tilingData)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto numTokensPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOKENS_INDEX));
|
||||
auto numRanksPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_RANKS_INDEX));
|
||||
auto numExpertsPtr = attrs->GetAttrPointer<int64_t>(ATTR_NUM_EXPERTS_INDEX);
|
||||
auto numTopkPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOPK_INDEX));
|
||||
|
||||
OPS_CHECK(numTokensPtr == nullptr, OPS_LOG_E(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(numRanksPtr == nullptr, OPS_LOG_E(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(numExpertsPtr == nullptr, OPS_LOG_E(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(numTopkPtr == nullptr, OPS_LOG_E(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", MAX_COMM_WORLD_SIZE, *numRanksPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*numExpertsPtr <= 0) || (*numExpertsPtr > MAX_MOE_EXPERTS_NUM),
|
||||
OPS_LOG_E(nodeName, "numExperts is invalid, only support (0, %ld], but got numExperts=%ld.", MAX_MOE_EXPERTS_NUM, *numExpertsPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*numTopkPtr <= 0) || (*numTopkPtr > K_MAX),
|
||||
OPS_LOG_E(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
tilingData.dispatchLayoutInfo.numTokens = static_cast<uint32_t>(*numTokensPtr);
|
||||
tilingData.dispatchLayoutInfo.numRanks = static_cast<uint32_t>(*numRanksPtr);
|
||||
tilingData.dispatchLayoutInfo.numExperts = static_cast<uint32_t>(*numExpertsPtr);
|
||||
tilingData.dispatchLayoutInfo.numTopk = static_cast<uint32_t>(*numTopkPtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto topkIdx = context->GetInputDesc(INPUT_TOPK_IDX_INDEX);
|
||||
auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX);
|
||||
auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
|
||||
auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX);
|
||||
|
||||
OPS_CHECK(topkIdx == nullptr, OPS_LOG_E(nodeName, "topkIdx is null."), return false);
|
||||
OPS_CHECK(numTokensPerRank == nullptr, OPS_LOG_E(nodeName, "numTokensPerRank is null."), return false);
|
||||
OPS_CHECK(numTokensPerExpert == nullptr, OPS_LOG_E(nodeName, "numTokensPerExpert is null."), return false);
|
||||
OPS_CHECK(isTokenInRank == nullptr, OPS_LOG_E(nodeName, "isTokenInRank is null."), return false);
|
||||
|
||||
OPS_CHECK((topkIdx->GetDataType() != ge::DT_INT64),
|
||||
OPS_LOG_E(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.",
|
||||
static_cast<ge::DataType>(topkIdx->GetDataType())), return false);
|
||||
OPS_CHECK((numTokensPerRank->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName, "numTokensPerRank datatype is invalid, datatype should be int, but is %d.",
|
||||
static_cast<ge::DataType>(numTokensPerRank->GetDataType())), return false);
|
||||
OPS_CHECK((numTokensPerExpert->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName, "numTokensPerExpert datatype is invalid, datatype should be int, but is %d.",
|
||||
static_cast<ge::DataType>(numTokensPerExpert->GetDataType())), return false);
|
||||
OPS_CHECK((isTokenInRank->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.",
|
||||
static_cast<ge::DataType>(isTokenInRank->GetDataType())), return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorShape(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
const gert::StorageShape *topkIdxStorageShape = context->GetInputShape(INPUT_TOPK_IDX_INDEX);
|
||||
int64_t topkIdxDim0 = topkIdxStorageShape->GetStorageShape().GetDim(0);
|
||||
int64_t topkIdxDim1 = topkIdxStorageShape->GetStorageShape().GetDim(1);
|
||||
|
||||
OPS_CHECK((topkIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS),
|
||||
OPS_LOG_E(nodeName, "topkIdx must be 2-dimension, but get %lu dim.",
|
||||
topkIdxStorageShape->GetStorageShape().GetDimNum()), return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckTensor(
|
||||
gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(!CheckTensorShape(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
DispatchLayoutTilingData *tilingData = context->GetTilingData<DispatchLayoutTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
OPS_LOG_I(nodeName, "Enter NotifyDispatch tiling check func.");
|
||||
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check param failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
|
||||
fe::PlatFormInfos &platformInfo = *platformInfoPtr;
|
||||
|
||||
std::string socVersion;
|
||||
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);
|
||||
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t blockDim;
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
|
||||
blockDim = aivNum;
|
||||
context->SetBlockDim(blockDim);
|
||||
tilingData->dispatchLayoutInfo.totalUbSize = ubSize;
|
||||
OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus DispatchLayoutTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
|
||||
fe::PlatFormInfos &platformInfo = *platformInfoPtr;
|
||||
|
||||
std::string socVersion;
|
||||
ge::graphStatus ret;
|
||||
ret = DispatchLayoutTilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct DispatchLayoutCompileInfo {};
|
||||
ge::graphStatus TilingParseForDispatchLayout(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(DispatchLayout)
|
||||
.Tiling(DispatchLayoutTilingFunc)
|
||||
.TilingParse<DispatchLayoutCompileInfo>(TilingParseForDispatchLayout);
|
||||
} // namespace optiling
|
||||
17
csrc/dispatch_layout/op_kernel/dispatch_layout.cpp
Normal file
17
csrc/dispatch_layout/op_kernel/dispatch_layout.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "dispatch_layout.h"
|
||||
#include "dispatch_layout_tiling.h"
|
||||
|
||||
|
||||
extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert,
|
||||
GM_ADDR isTokenInRank, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(DispatchLayoutTilingData);
|
||||
GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling);
|
||||
|
||||
TPipe pipe;
|
||||
|
||||
DispatchLayout<int32_t> op;
|
||||
op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, workspace, &pipe, &tilingData);
|
||||
op.Process();
|
||||
}
|
||||
153
csrc/dispatch_layout/op_kernel/dispatch_layout.h
Normal file
153
csrc/dispatch_layout/op_kernel/dispatch_layout.h
Normal file
@@ -0,0 +1,153 @@
|
||||
#ifndef DISPATCH_LAYOUT_H
|
||||
#define DISPATCH_LAYOUT_H
|
||||
|
||||
#include <climits>
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#include "../common/comm_args.h"
|
||||
#include "../common/data_copy.h"
|
||||
#include "../common/sync_collectives.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "dispatch_layout_tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace Moe;
|
||||
|
||||
constexpr uint32_t UB_32_ALIGN = 32U;
|
||||
constexpr uint32_t AIV_NUM = 48;
|
||||
|
||||
template <AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class DispatchLayout {
|
||||
|
||||
public:
|
||||
__aicore__ inline DispatchLayout() {};
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank,
|
||||
GM_ADDR workspace, TPipe *pipe, const DispatchLayoutTilingData *tilingData)
|
||||
{
|
||||
numTokens_ = tilingData->dispatchLayoutInfo.numTokens;
|
||||
numRanks_ = tilingData->dispatchLayoutInfo.numRanks;
|
||||
numExperts_ = tilingData->dispatchLayoutInfo.numExperts;
|
||||
numTopk_ = tilingData->dispatchLayoutInfo.numTopk;
|
||||
tpipe_ = pipe;
|
||||
|
||||
coreIdx_ = GetBlockIdx();
|
||||
uint32_t temp = numTokens_ / AIV_NUM;
|
||||
uint32_t restNum = numTokens_ % AIV_NUM;
|
||||
int64_t topkIdxOffset;
|
||||
int64_t isTokenOffset;
|
||||
tempTokens_ = temp;
|
||||
if (coreIdx_ < restNum) {
|
||||
tempTokens_++;
|
||||
}
|
||||
topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN;
|
||||
numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;
|
||||
numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;
|
||||
isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;
|
||||
|
||||
if (coreIdx_ < restNum) {
|
||||
topkIdxOffset = coreIdx_ * topkIdx32AlignIntLen_;
|
||||
isTokenOffset = coreIdx_ * isTokenInRank32AlignIntLen_;
|
||||
} else {
|
||||
topkIdxOffset = restNum * Ceil((tempTokens_ + 1) * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN
|
||||
+ (coreIdx_ - restNum) * topkIdx32AlignIntLen_;
|
||||
isTokenOffset = restNum * Ceil((tempTokens_ + 1) * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN
|
||||
+ (coreIdx_ - restNum) * isTokenInRank32AlignIntLen_;
|
||||
}
|
||||
|
||||
topkIdxGM_.SetGlobalBuffer((__gm__ int64_t*)(topkIdx + topkIdxOffset));
|
||||
numTokensPerRankGM_.SetGlobalBuffer((__gm__ T*)numTokensPerRank);
|
||||
numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T*)numTokensPerExpert);
|
||||
isTokenInRankGM_.SetGlobalBuffer((__gm__ T*)(isTokenInRank + isTokenOffset));
|
||||
|
||||
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
tpipe_->Reset();
|
||||
tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_);
|
||||
tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_);
|
||||
tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_);
|
||||
tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_);
|
||||
tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T));
|
||||
|
||||
LocalTensor<int64_t> topkIdxTensor = topkIdxBuf_.AllocTensor<int64_t>();
|
||||
const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U};
|
||||
const DataCopyPadExtParams<int64_t> padParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
LocalTensor<T> numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor<T>();
|
||||
LocalTensor<T> numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor<T>();
|
||||
LocalTensor<T> isTokenInRankTensor = isTokenInRankBuf_.AllocTensor<T>();
|
||||
LocalTensor<T> seenRankTensor = seenRankBuf_.AllocTensor<T>();
|
||||
Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
|
||||
Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
|
||||
Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
|
||||
int experts_per_rank = numExperts_ / numRanks_;
|
||||
for (int i = 0; i < tempTokens_; ++i) {
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Duplicate<T>(seenRankTensor, 0, numRanks_);
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
for (int j = 0; j < numTopk_; ++j) {
|
||||
int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j);
|
||||
uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1;
|
||||
numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num);
|
||||
int rank_id = expert_idx / experts_per_rank;
|
||||
if (!seenRankTensor.GetValue(rank_id)) {
|
||||
uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1;
|
||||
isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1);
|
||||
seenRankTensor.SetValue(rank_id, 1);
|
||||
numTokensPerRankTensor.SetValue(rank_id, per_rank_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DataCopyExtParams isTokenInRankDataCopyParams{1U, isTokenInRank32AlignIntLen_, 0U, 0U, 0U};
|
||||
DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams);
|
||||
AscendC::SetAtomicAdd<T>();
|
||||
const DataCopyExtParams numTokensPerRankDataCopyParams{1U, numTokensPerRank32AlignIntLen_, 0U, 0U, 0U};
|
||||
DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams);
|
||||
const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U};
|
||||
DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams);
|
||||
AscendC::SetAtomicNone();
|
||||
}
|
||||
|
||||
private:
|
||||
GlobalTensor<int64_t> topkIdxGM_;
|
||||
GlobalTensor<T> numTokensPerRankGM_;
|
||||
GlobalTensor<T> numTokensPerExpertGM_;
|
||||
GlobalTensor<T> isTokenInRankGM_;
|
||||
|
||||
TBuf<> topkIdxBuf_;
|
||||
TBuf<> numTokensPerRankBuf_;
|
||||
TBuf<> numTokensPerExpertBuf_;
|
||||
TBuf<> isTokenInRankBuf_;
|
||||
TBuf<> seenRankBuf_;
|
||||
|
||||
TPipe *tpipe_{nullptr};
|
||||
uint32_t numTokens_{0};
|
||||
uint32_t numRanks_{0};
|
||||
uint32_t numExperts_{0};
|
||||
uint32_t numTopk_{0};
|
||||
uint32_t coreIdx_{0};
|
||||
uint32_t tempTokens_{0};
|
||||
|
||||
uint32_t topkIdx32AlignIntLen_{0};
|
||||
uint32_t numTokensPerRank32AlignIntLen_{0};
|
||||
uint32_t numTokensPerExpert32AlignIntLen_{0};
|
||||
uint32_t isTokenInRank32AlignIntLen_{0};
|
||||
};
|
||||
|
||||
#endif // DISPATCH_LAYOUT_H
|
||||
20
csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h
Normal file
20
csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef DISPATCH_LAYOUT_TILING_H
|
||||
#define DISPATCH_LAYOUT_TILING_H
|
||||
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
struct DispatchLayoutInfo {
|
||||
uint32_t numTokens;
|
||||
uint32_t numRanks;
|
||||
uint32_t numExperts;
|
||||
uint32_t numTopk;
|
||||
uint64_t totalUbSize;
|
||||
};
|
||||
|
||||
struct DispatchLayoutTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
DispatchLayoutInfo dispatchLayoutInfo;
|
||||
};
|
||||
|
||||
#endif
|
||||
49
csrc/moe_combine_normal/op_host/CMakeLists.txt
Normal file
49
csrc/moe_combine_normal/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME MoeCombineNormal
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
moe_combine_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_moe_combine_normal.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_moe_combine_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_moe_combine_normal.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
moe_combine_normal_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_combine_normal.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
77
csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp
Normal file
77
csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.cpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_moe_combine_normal.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
extern aclnnStatus aclnnInnerMoeCombineNormalGetWorkspaceSize(
|
||||
const aclTensor *recvX,
|
||||
const aclTensor *tokenSrcInfo,
|
||||
const aclTensor *epRecvCounts,
|
||||
const aclTensor *recvTopkWeights,
|
||||
const aclTensor *tpRecvCountsOptional,
|
||||
char *epGroupName,
|
||||
int64_t epWorldSize,
|
||||
int64_t epRankId,
|
||||
char *tpGroupNameOptional,
|
||||
int64_t tpWorldSize,
|
||||
int64_t tpRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t globalBs,
|
||||
const aclTensor *out,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerMoeCombineNormal(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize(
|
||||
const aclTensor *recvX,
|
||||
const aclTensor *tokenSrcInfo,
|
||||
const aclTensor *epRecvCounts,
|
||||
const aclTensor *recvTopkWeights,
|
||||
const aclTensor *tpRecvCountsOptional,
|
||||
char *epGroupName,
|
||||
int64_t epWorldSize,
|
||||
int64_t epRankId,
|
||||
char *tpGroupNameOptional,
|
||||
int64_t tpWorldSize,
|
||||
int64_t tpRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t globalBs,
|
||||
const aclTensor *out,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerMoeCombineNormalGetWorkspaceSize(recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights,
|
||||
tpRecvCountsOptional, epGroupName, epWorldSize, epRankId,
|
||||
tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum,
|
||||
globalBs, out, workspaceSize, executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnMoeCombineNormal(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerMoeCombineNormal(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
62
csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h
Normal file
62
csrc/moe_combine_normal/op_host/aclnn_moe_combine_normal.h
Normal file
@@ -0,0 +1,62 @@
|
||||
#ifndef ACLNN_MOE_COMBINE_NORMAL_H_
|
||||
#define ACLNN_MOE_COMBINE_NORMAL_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* funtion: aclnnMoeCombineGetWorkspaceSize
|
||||
* recvX : required
|
||||
* tokenSrcInfo : required
|
||||
* epRecvCounts : required
|
||||
* recvTopkWeights : required
|
||||
* tpRecvCountsOptional : required
|
||||
* epGroupName : optional
|
||||
* epWorldSize : required
|
||||
* epRankId : required
|
||||
* tpGroupNameOptional : required
|
||||
* tpWorldSize : optional
|
||||
* tpRankId : optional
|
||||
* moeExpertNum : optional
|
||||
* globalBs : optional
|
||||
* out : required
|
||||
* workspaceSize : size of workspace(output).
|
||||
* executor : executor context(output).
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize(
|
||||
const aclTensor *recvX,
|
||||
const aclTensor *tokenSrcInfo,
|
||||
const aclTensor *epRecvCounts,
|
||||
const aclTensor *recvTopkWeights,
|
||||
const aclTensor *tpRecvCountsOptional,
|
||||
char *epGroupName,
|
||||
int64_t epWorldSize,
|
||||
int64_t epRankId,
|
||||
char *tpGroupNameOptional,
|
||||
int64_t tpWorldSize,
|
||||
int64_t tpRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t globalBs,
|
||||
const aclTensor *out,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
/* funtion: aclnnMoeCombine
|
||||
* workspace : workspace memory addr(input).
|
||||
* workspaceSize : size of workspace(input).
|
||||
* executor : executor context(input).
|
||||
* stream : acl stream.
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormal(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
71
csrc/moe_combine_normal/op_host/moe_combine_normal.cpp
Normal file
71
csrc/moe_combine_normal/op_host/moe_combine_normal.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class MoeCombineNormal : public OpDef {
|
||||
public:
|
||||
explicit MoeCombineNormal(const char* name) : OpDef(name) {
|
||||
this->Input("recv_x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("token_src_info")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("ep_recv_counts")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("recv_topk_weights")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("tp_recv_counts")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Output("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("ep_group_name").AttrType(REQUIRED).String();
|
||||
this->Attr("ep_world_size").AttrType(REQUIRED).Int();
|
||||
this->Attr("ep_rank_id").AttrType(REQUIRED).Int();
|
||||
this->Attr("tp_group_name").AttrType(OPTIONAL).String("");
|
||||
this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("moe_expert_num").AttrType(REQUIRED).Int();
|
||||
this->Attr("global_bs").AttrType(OPTIONAL).Int(0);
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_true")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->MC2().HcclGroup({"ep_group_name", "tp_group_name"});
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(MoeCombineNormal);
|
||||
|
||||
} // namespace ops
|
||||
546
csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp
Normal file
546
csrc/moe_combine_normal/op_host/moe_combine_normal_tiling.cpp
Normal file
@@ -0,0 +1,546 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/moe_combine_normal_tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace ge;
|
||||
|
||||
namespace {
|
||||
class Mc2TilingUtils {
|
||||
public:
|
||||
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
|
||||
static uint64_t GetMaxWindowSize()
|
||||
{
|
||||
uint16_t defaultWindowSize = 200;
|
||||
if (getenv(HCCL_BUFFSIZE) == nullptr) {
|
||||
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
|
||||
} else {
|
||||
try {
|
||||
std::string envStr(getenv(HCCL_BUFFSIZE));
|
||||
defaultWindowSize = std::stoi(envStr);
|
||||
} catch (...) {
|
||||
OPS_LOG_E("", "Unknown Exception encountered when parser env HCCL_BUFFERSIZE");
|
||||
}
|
||||
}
|
||||
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
|
||||
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
|
||||
return maxWindowSize;
|
||||
}
|
||||
};
|
||||
constexpr uint32_t RECV_X_INDEX = 0;
|
||||
constexpr uint32_t TOKEN_SRC_INFO_INDEX = 1;
|
||||
constexpr uint32_t EP_RECV_COUNTS_INDEX = 2;
|
||||
constexpr uint32_t TOPK_WEIGHTS_INDEX = 3;
|
||||
constexpr uint32_t TP_RECV_COUNTS_INDEX = 4;
|
||||
constexpr uint32_t OUTPUT_X_INDEX = 0;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
|
||||
constexpr uint32_t ATTR_GROUP_TP_INDEX = 3;
|
||||
constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4;
|
||||
constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5;
|
||||
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6;
|
||||
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
|
||||
|
||||
constexpr uint32_t TWO_DIMS = 2U;
|
||||
constexpr uint32_t ONE_DIM = 1U;
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll
|
||||
constexpr uint32_t OP_TYPE_REDUCE_SCATTER = 7U; // numeric representation of ReduceScatter
|
||||
|
||||
constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
|
||||
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
|
||||
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
|
||||
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
|
||||
constexpr int64_t BS_UPPER_BOUND = 8000;
|
||||
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes
|
||||
constexpr int64_t MOE_EXPERT_MAX_NUM = 512;
|
||||
constexpr int64_t K_MAX = 16;
|
||||
constexpr int64_t H_MIN = 1024;
|
||||
constexpr int64_t H_MAX = 7168;
|
||||
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
|
||||
constexpr uint64_t TRIPLE = 3;
|
||||
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
|
||||
constexpr uint64_t SCALE_RECV_IDX_BUFFER = 44UL; // scale32B + 3*4 src info
|
||||
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL;
|
||||
constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
|
||||
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
|
||||
constexpr uint64_t UB_ALIGN = 32UL;
|
||||
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
|
||||
|
||||
enum class CommQuantMode : int32_t {
|
||||
NON_QUANT = 0,
|
||||
INT12_QUANT = 1,
|
||||
INT8_QUANT = 2
|
||||
};
|
||||
using CommQuantModeType = std::underlying_type<CommQuantMode>;
|
||||
}
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Specific to A3
|
||||
static void PrintTilingDataInfo(const char *nodeName, MoeCombineNormalTilingData& tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeCombineNormalInfo.epWorldSize);
|
||||
OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeCombineNormalInfo.tpWorldSize);
|
||||
OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeCombineNormalInfo.epRankId);
|
||||
OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeCombineNormalInfo.tpRankId);
|
||||
OPS_LOG_D(nodeName, "expertShardType is %u.", tilingData.moeCombineNormalInfo.expertShardType);
|
||||
OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeCombineNormalInfo.moeExpertNum);
|
||||
OPS_LOG_D(nodeName, "moeExpertPerRankNum is %u.", tilingData.moeCombineNormalInfo.moeExpertPerRankNum);
|
||||
OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeCombineNormalInfo.globalBs);
|
||||
OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeCombineNormalInfo.bs);
|
||||
OPS_LOG_D(nodeName, "k is %u.", tilingData.moeCombineNormalInfo.k);
|
||||
OPS_LOG_D(nodeName, "h is %u.", tilingData.moeCombineNormalInfo.h);
|
||||
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeCombineNormalInfo.aivNum);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeCombineNormalInfo.totalUbSize);
|
||||
OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeCombineNormalInfo.totalWinSize);
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
|
||||
const char *nodeName, std::string &groupEp, std::string &groupTp)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
|
||||
auto groupTpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_TP_INDEX));
|
||||
auto epWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_WORLD_SIZE_INDEX);
|
||||
auto tpWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_WORLD_SIZE_INDEX);
|
||||
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
|
||||
auto tpRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_RANK_ID_INDEX);
|
||||
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
|
||||
|
||||
// Check for null
|
||||
OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), OPS_LOG_E(nodeName, "groupEp is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSize is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSize is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankId is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankId is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNum is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if it meets uint32_t and other constraints
|
||||
int64_t moeExpertNum = *moeExpertNumPtr;
|
||||
int64_t epWorldSize = *epWorldSizePtr;
|
||||
OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName, "epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.",
|
||||
MIN_EP_WORLD_SIZE, MAX_EP_WORLD_SIZE, epWorldSize), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName, "tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.",
|
||||
MAX_TP_WORLD_SIZE, *tpWorldSizePtr), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize),
|
||||
OPS_LOG_E(nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.",
|
||||
epWorldSize, *epRankIdPtr), return ge::GRAPH_FAILED);
|
||||
|
||||
if (*tpWorldSizePtr > 1) {
|
||||
OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr),
|
||||
OPS_LOG_E(nodeName, "tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.",
|
||||
*tpWorldSizePtr, *tpRankIdPtr), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "groupTpPtr is null."), return ge::GRAPH_FAILED);
|
||||
groupTp = std::string(groupTpPtr);
|
||||
} else {
|
||||
OPS_CHECK(*tpRankIdPtr != 0,
|
||||
OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM),
|
||||
OPS_LOG_E(nodeName, "moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.",
|
||||
MOE_EXPERT_MAX_NUM, moeExpertNum), return ge::GRAPH_FAILED);
|
||||
int64_t moePerRankNum = moeExpertNum / epWorldSize;
|
||||
int64_t curDispatchStatusNum = moePerRankNum * epWorldSize;
|
||||
OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM),
|
||||
OPS_LOG_E(nodeName, "The moe experts num must meet the conditions,"
|
||||
" (moeExpertNum / epWorldSize) * epWorldSize <= 1280, but cur is %ld.",
|
||||
curDispatchStatusNum), return ge::GRAPH_FAILED);
|
||||
|
||||
groupEp = std::string(groupEpPtr);
|
||||
tilingData.moeCombineNormalInfo.epWorldSize = static_cast<uint32_t>(epWorldSize);
|
||||
tilingData.moeCombineNormalInfo.tpWorldSize = static_cast<uint32_t>(*tpWorldSizePtr);
|
||||
tilingData.moeCombineNormalInfo.epRankId = static_cast<uint32_t>(*epRankIdPtr);
|
||||
tilingData.moeCombineNormalInfo.tpRankId = static_cast<uint32_t>(*tpRankIdPtr);
|
||||
tilingData.moeCombineNormalInfo.moeExpertNum = static_cast<uint32_t>(moeExpertNum);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static bool CheckInputTensorDim(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX);
|
||||
OPS_CHECK(recvXStorageShape == nullptr, OPS_LOG_E(nodeName, "recvX is null."), return false);
|
||||
OPS_CHECK(recvXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName, "recvX must be 2-dimension, but got %lu dim",
|
||||
recvXStorageShape->GetStorageShape().GetDimNum()), return false);
|
||||
OPS_LOG_D(nodeName, "recvX dim0 = %ld", recvXStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "recvX dim1 = %ld", recvXStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
const gert::StorageShape *tokenSrcInfoStorageShape = context->GetInputShape(TOKEN_SRC_INFO_INDEX);
|
||||
OPS_CHECK(tokenSrcInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoForCombine is null."), return false);
|
||||
OPS_CHECK(tokenSrcInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
|
||||
OPS_LOG_E(nodeName, "tokenSrcInfoForCombine must be 1-dimension, but got %lu dim",
|
||||
tokenSrcInfoStorageShape->GetStorageShape().GetDimNum()), return false);
|
||||
OPS_LOG_D(nodeName, "tokenSrcInfoForCombine dim0 = %ld", tokenSrcInfoStorageShape->GetStorageShape().GetDim(0));
|
||||
|
||||
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
|
||||
OPS_CHECK(topkWeightsStorageShape == nullptr, OPS_LOG_E(nodeName, "topkWeights is null."), return false);
|
||||
OPS_CHECK(topkWeightsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName, "topkWeights must be 2-dimension, but got %lu dim",
|
||||
topkWeightsStorageShape->GetStorageShape().GetDimNum()), return false);
|
||||
OPS_LOG_D(nodeName, "topkWeights dim0 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "topkWeights dim1 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
const gert::StorageShape *tpRecvCountsStorageShape = context->GetOptionalInputShape(TP_RECV_COUNTS_INDEX);
|
||||
OPS_CHECK(tpRecvCountsStorageShape == nullptr, OPS_LOG_E(nodeName, "tpRecvCounts is null."), return false);
|
||||
OPS_CHECK(tpRecvCountsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
|
||||
OPS_LOG_E(nodeName, "tpRecvCounts must be 1-dimension, but got %lu dim",
|
||||
tpRecvCountsStorageShape->GetStorageShape().GetDimNum()), return false);
|
||||
OPS_LOG_D(nodeName, "tpRecvCounts dim0 = %ld", tpRecvCountsStorageShape->GetStorageShape().GetDim(0));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
|
||||
OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x is null."), return false);
|
||||
OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName, "x must be 2-dimension, but got %lu dim", xStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
OPS_CHECK(!CheckInputTensorDim(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param shape of input tensor is invalid"), return false);
|
||||
|
||||
OPS_CHECK(!CheckOptionalInputTensorDim(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param shape of optional input tensor is invalid"), return false);
|
||||
|
||||
OPS_CHECK(!CheckOutputTensorDim(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param shape of output tensor is invalid"), return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Validate data type
|
||||
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
|
||||
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false);
|
||||
OPS_CHECK((recvXDesc->GetDataType() != ge::DT_BF16) && (recvXDesc->GetDataType() != ge::DT_FLOAT16),
|
||||
OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is "
|
||||
), return false);
|
||||
auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX);
|
||||
OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false);
|
||||
OPS_CHECK((tokenSrcInfoDesc->GetDataType() != ge::DT_INT32), OPS_LOG_E(nodeName, "tokenSrcInfoForCombine dataType is invalid,"
|
||||
" dataType should be int32, but is"), return false);
|
||||
auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX);
|
||||
OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false);
|
||||
OPS_CHECK((tpRecvCountsDesc->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName, "tpRecvCounts dataType is invalid, dataType should be int32, but is "), return false);
|
||||
auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX);
|
||||
OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false);
|
||||
OPS_CHECK((topkWeightsDesc->GetDataType() != ge::DT_FLOAT),
|
||||
OPS_LOG_E(nodeName, "topkWeights dataType is invalid, dataType should be float, but is "),
|
||||
return false);
|
||||
auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()), OPS_LOG_E(nodeName,
|
||||
"x dataType is invalid, dataType should be equal to recvX dataType , but is "),
|
||||
return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
|
||||
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(recvXDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "recvXFormat is invalid"), return false);
|
||||
|
||||
auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX);
|
||||
OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(tokenSrcInfoDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tokenSrcInfoFormat is invalid"), return false);
|
||||
|
||||
auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX);
|
||||
OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(tpRecvCountsDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tpRecvCountsFormat is invalid"), return false);
|
||||
|
||||
auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX);
|
||||
OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(topkWeightsDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "topkWeightsFormat is invalid"), return false);
|
||||
|
||||
auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "xFormat is invalid"), return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorShape(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
|
||||
const char *nodeName, uint32_t localExpertNum)
|
||||
{
|
||||
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
|
||||
int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0);
|
||||
int64_t topkWeightsDim1 = topkWeightsStorageShape->GetStorageShape().GetDim(1);
|
||||
int64_t moeExpertNum = static_cast<int64_t>(tilingData.moeCombineNormalInfo.moeExpertNum);
|
||||
OPS_CHECK((topkWeightsDim1 <= 0) || (topkWeightsDim1 > K_MAX || (topkWeightsDim1 > moeExpertNum)),
|
||||
OPS_LOG_E(nodeName, "topkWeights's dim1(K) should be in (0, min(%ld, moeExpertNum %ld)], "
|
||||
"but got topkWeights's dim1=%ld.", K_MAX, moeExpertNum, topkWeightsDim1), return false);
|
||||
tilingData.moeCombineNormalInfo.k = static_cast<uint32_t>(topkWeightsDim1);
|
||||
|
||||
// Validate recvX dimensions and set h
|
||||
int64_t tpWorldSize = static_cast<int64_t>(tilingData.moeCombineNormalInfo.tpWorldSize);
|
||||
const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX);
|
||||
int64_t recvXDim1 = recvXStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK((recvXDim1 < H_MIN) || (recvXDim1 > H_MAX),
|
||||
OPS_LOG_E(nodeName, "recvX's dim1(H) should be in [%ld, %ld], but got %ld.",
|
||||
H_MIN, H_MAX, recvXDim1), return false); // 32-byte aligned
|
||||
tilingData.moeCombineNormalInfo.h = static_cast<uint32_t>(recvXDim1);
|
||||
|
||||
// Validate epRecvCount and tpRecvCount dimensions
|
||||
int64_t epWorldSize = static_cast<int64_t>(tilingData.moeCombineNormalInfo.epWorldSize);
|
||||
int64_t moeExpertPerRankNum = static_cast<int64_t>(tilingData.moeCombineNormalInfo.moeExpertPerRankNum);
|
||||
|
||||
// Validate x dimensions
|
||||
const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
|
||||
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK(xDim0 != topkWeightsDim0, OPS_LOG_E(nodeName,
|
||||
"x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false);
|
||||
OPS_CHECK(xDim1 != recvXDim1, OPS_LOG_E(nodeName,
|
||||
"x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckAttrs(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
|
||||
const char *nodeName, uint32_t &localMoeExpertNum)
|
||||
{
|
||||
uint32_t epWorldSize = tilingData.moeCombineNormalInfo.epWorldSize;
|
||||
uint32_t tpWorldSize = tilingData.moeCombineNormalInfo.tpWorldSize;
|
||||
uint32_t moeExpertNum = tilingData.moeCombineNormalInfo.moeExpertNum;
|
||||
|
||||
// Validate if moe expert number can be evenly distributed across multiple machines
|
||||
OPS_CHECK(moeExpertNum % epWorldSize != 0,
|
||||
OPS_LOG_E(nodeName, "moeExpertNum should be divisible by epWorldSize, "
|
||||
"but got moeExpertNum=%d, epWorldSize=%d.", moeExpertNum, epWorldSize), return false);
|
||||
localMoeExpertNum = moeExpertNum / epWorldSize;
|
||||
OPS_CHECK(localMoeExpertNum <= 0,
|
||||
OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum), return false);
|
||||
// Validate if expert number per card equals 1 when tp=2
|
||||
OPS_CHECK((localMoeExpertNum > 1) && (tpWorldSize > 1),
|
||||
OPS_LOG_E(nodeName, "Cannot support multi-moeExpert %d in a rank when tpWorldSize = %d > 1",
|
||||
localMoeExpertNum, tpWorldSize), return false);
|
||||
tilingData.moeCombineNormalInfo.moeExpertPerRankNum = localMoeExpertNum;
|
||||
|
||||
// Validate topkWeights dimension 0 and set bs
|
||||
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
|
||||
int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0);
|
||||
OPS_CHECK((topkWeightsDim0 <= 0) || (topkWeightsDim0 > BS_UPPER_BOUND),
|
||||
OPS_LOG_E(nodeName, "Invalid topkWeights dims0(BS) %ld. Should be between [1, %ld].",
|
||||
topkWeightsDim0, BS_UPPER_BOUND), return false);
|
||||
tilingData.moeCombineNormalInfo.bs = static_cast<uint32_t>(topkWeightsDim0);
|
||||
|
||||
// Validate globalBS
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return false);
|
||||
auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
|
||||
OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBs is null."), return false);
|
||||
OPS_LOG_D(nodeName, "MoeCombineNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n",
|
||||
*globalBsPtr, topkWeightsDim0, epWorldSize);
|
||||
|
||||
OPS_CHECK((*globalBsPtr != 0) && ((*globalBsPtr < static_cast<int64_t>(epWorldSize) * topkWeightsDim0) ||
|
||||
((*globalBsPtr) % (static_cast<int64_t>(epWorldSize)) != 0)), OPS_LOG_E(nodeName, "globalBS is invalid, only "
|
||||
"support 0 or maxBs(maxBs is the largest bs on all ranks) * epWorldSize, but got globalBS=%ld, "
|
||||
"bs=%ld, epWorldSize=%u.", *globalBsPtr, topkWeightsDim0, epWorldSize), return false);
|
||||
|
||||
tilingData.moeCombineNormalInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
|
||||
if (*globalBsPtr == 0) {
|
||||
tilingData.moeCombineNormalInfo.globalBs = static_cast<uint32_t>(topkWeightsDim0) * epWorldSize;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckMoeCombineNormal(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
// Check parameter shape information
|
||||
OPS_CHECK(!CheckTensorDim(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param shape is invalid"), return ge::GRAPH_FAILED);
|
||||
// Check parameter dataType information
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param dataType is invalid"), return ge::GRAPH_FAILED);
|
||||
// Check parameter format information
|
||||
OPS_CHECK(!CheckTensorFormat(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "param Format is invalid"), return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkspace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workspace = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workspace == nullptr, OPS_LOG_E(nodeName, "get workspace failed"),
|
||||
return ge::GRAPH_FAILED);
|
||||
workspace[0] = SYSTEM_NEED_WORKSPACE;
|
||||
OPS_LOG_D(nodeName, "workspce[0] size is %ld", workspace[0]);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
static void SetHCommCfg(gert::TilingContext *context, MoeCombineNormalTilingData *tiling,
|
||||
const std::string groupEp, const std::string groupTp)
|
||||
{
|
||||
const char* nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "MoeCombineNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str());
|
||||
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
|
||||
uint32_t opType2 = OP_TYPE_REDUCE_SCATTER;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
std::string algConfigReduceScatterStr = "ReduceScatter=level0:ring";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
|
||||
|
||||
mc2CcTilingConfig.SetGroupName(groupTp);
|
||||
mc2CcTilingConfig.SetOpType(opType2);
|
||||
mc2CcTilingConfig.SetAlgConfig(algConfigReduceScatterStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2);
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeCombineNormalA3TilingFuncImpl(gert::TilingContext* context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "Enter MoeCombineNormal Tiling func");
|
||||
MoeCombineNormalTilingData *tilingData = context->GetTilingData<MoeCombineNormalTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
std::string groupEp = "";
|
||||
std::string groupTp = "";
|
||||
uint32_t localMoeExpertNum = 1;
|
||||
|
||||
// Get input parameter attributes
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED,
|
||||
OPS_LOG_E(nodeName, "Getting attr failed."), return ge::GRAPH_FAILED);
|
||||
|
||||
// Check input/output dim, format, dataType
|
||||
OPS_CHECK(TilingCheckMoeCombineNormal(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check params failed"), return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if attribute values are valid
|
||||
OPS_CHECK(!CheckAttrs(context, *tilingData, nodeName, localMoeExpertNum),
|
||||
OPS_LOG_E(nodeName, "attr check failed."), return ge::GRAPH_FAILED);
|
||||
|
||||
uint32_t epRankId = tilingData->moeCombineNormalInfo.epRankId;
|
||||
|
||||
// Check shape dimensions and assign h, k
|
||||
OPS_CHECK(!CheckTensorShape(context, *tilingData, nodeName, localMoeExpertNum),
|
||||
OPS_LOG_E(nodeName, "param dim check failed."), return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate win area size
|
||||
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
|
||||
uint64_t h = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.h);
|
||||
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.epWorldSize);
|
||||
uint64_t k = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.k);
|
||||
uint64_t maxBs = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.globalBs)/ epWorldSize;
|
||||
// Combine data area: token start address aligned to 512
|
||||
uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
// Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b)
|
||||
uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_RECV_IDX_BUFFER;
|
||||
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET) *
|
||||
DOUBLE_DATA_BUFFER;
|
||||
OPS_CHECK((actualSize > maxWindowSize),
|
||||
OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
|
||||
" tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE("
|
||||
"((maxBs * tokenNeedSizeDispatch) + (maxBs * tokenNeedSizeCombine * k) + 3MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.",
|
||||
maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k,
|
||||
actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData->moeCombineNormalInfo.totalWinSize = maxWindowSize;
|
||||
|
||||
OPS_CHECK(SetWorkspace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(context->GetNodeName(), "Tiling set workspace Failed"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
SetHCommCfg(context, tilingData, groupEp, groupTp);
|
||||
|
||||
uint64_t tpWorldSize = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.tpWorldSize);
|
||||
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint64_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
|
||||
context->SetBlockDim(blockDim);
|
||||
tilingData->moeCombineNormalInfo.aivNum = aivNum;
|
||||
tilingData->moeCombineNormalInfo.totalUbSize = ubSize;
|
||||
context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously
|
||||
OPS_LOG_D(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeCombineNormalTilingFunc(gert::TilingContext* context)
|
||||
{
|
||||
// recvX data type int32 is not supported
|
||||
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return ge::GRAPH_FAILED);
|
||||
// Check if recvX data type is DT_INT32
|
||||
OPS_CHECK((recvXDesc->GetDataType() == ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is "),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
ge::graphStatus ret = MoeCombineNormalA3TilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct MoeCombineNormalCompileInfo {};
|
||||
ge::graphStatus TilingParseForMoeCombineNormal(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(MoeCombineNormal)
|
||||
.Tiling(MoeCombineNormalTilingFunc)
|
||||
.TilingParse<MoeCombineNormalCompileInfo>(TilingParseForMoeCombineNormal);
|
||||
} // namespace optiling
|
||||
22
csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp
Normal file
22
csrc/moe_combine_normal/op_kernel/moe_combine_normal.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "lib/matmul_intf.h"
|
||||
#include "moe_combine_normal.h"
|
||||
#include "moe_combine_normal_tiling.h"
|
||||
using namespace AscendC;
|
||||
using namespace MoeCombineNormalImpl;
|
||||
|
||||
extern "C" __global__ __aicore__ void moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount,
|
||||
GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut,
|
||||
GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(MoeCombineNormalTilingData);
|
||||
TPipe pipe;
|
||||
|
||||
#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
|
||||
GET_TILING_DATA_WITH_STRUCT(MoeCombineNormalTilingData, tilingData, tilingGM);
|
||||
MoeCombineNormal<DTYPE_RECV_X, DTYPE_X, int32_t> op;
|
||||
op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, workspaceGM, &pipe, &tilingData);
|
||||
op.Process();
|
||||
#endif
|
||||
}
|
||||
377
csrc/moe_combine_normal/op_kernel/moe_combine_normal.h
Normal file
377
csrc/moe_combine_normal/op_kernel/moe_combine_normal.h
Normal file
@@ -0,0 +1,377 @@
|
||||
#ifndef MOE_COMBINE_NORMAL_H
|
||||
#define MOE_COMBINE_NORMAL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_combine_normal_tiling.h"
|
||||
|
||||
namespace MoeCombineNormalImpl {
|
||||
constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U;
|
||||
constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U;
|
||||
constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U;
|
||||
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL;
|
||||
constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL;
|
||||
constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U;
|
||||
constexpr uint32_t UB_32_ALIGN = 32U;
|
||||
constexpr uint32_t MUL_256_ALIGN = 256U;
|
||||
constexpr uint64_t WIN_512_ALIGN = 512UL;
|
||||
constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U;
|
||||
constexpr uint8_t DOUBLE_BUFFER = 2;
|
||||
|
||||
template<AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
#define TemplateMC2TypeClass typename RecvXType, typename XType, typename SrcInfoType
|
||||
#define TemplateMC2TypeFunc RecvXType, XType, SrcInfoType
|
||||
|
||||
using namespace AscendC;
|
||||
template <TemplateMC2TypeClass>
|
||||
class MoeCombineNormal {
|
||||
public:
|
||||
__aicore__ inline MoeCombineNormal() {};
|
||||
__aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights,
|
||||
GM_ADDR tpRecvCount,GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
|
||||
const MoeCombineNormalTilingData *tilingData);
|
||||
__aicore__ inline void Process();
|
||||
private:
|
||||
__aicore__ inline void InitMagic();
|
||||
__aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount,
|
||||
GM_ADDR topkWeights, GM_ADDR XOut);
|
||||
__aicore__ inline void InitTilingData(const MoeCombineNormalTilingData *tilingData);
|
||||
__aicore__ inline void InitBuffLen();
|
||||
__aicore__ inline void CopyBufferToShareAndSetStatus();
|
||||
__aicore__ inline void CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex);
|
||||
__aicore__ inline void ReadBufferFromRemote();
|
||||
__aicore__ inline void WaitBuffCopy(uint32_t tokenIndex);
|
||||
__aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId);
|
||||
__aicore__ inline void ReadBufferAndWeightedSum(uint32_t tokenIndex, uint32_t startTokenIndex);
|
||||
|
||||
__aicore__ GM_ADDR GetStateAddrByRankId(const int32_t rankId)
|
||||
{
|
||||
GM_ADDR bufferAddr;
|
||||
if (epRankId_ == rankId) {
|
||||
bufferAddr = (GM_ADDR)epWinContext_->localWindowsIn;
|
||||
} else {
|
||||
bufferAddr = (GM_ADDR)((HcclRankRelationResV2 *)epWinContext_->remoteRes[rankId].nextDevicePtr)->windowsIn;
|
||||
}
|
||||
return (GM_ADDR)(bufferAddr + winDataSizeOffset_);
|
||||
}
|
||||
|
||||
__aicore__ GM_ADDR GetBufferAddrByRankId(const int32_t rankId)
|
||||
{
|
||||
return GetStateAddrByRankId(rankId) + COMBINE_STATE_WIN_OFFSET;
|
||||
}
|
||||
|
||||
__aicore__ inline void SplitCoreCal(uint32_t totalNum, uint32_t &perCoreNum, uint32_t &startIdx, uint32_t &endIdx)
|
||||
{
|
||||
perCoreNum = totalNum / aivNum_;
|
||||
uint32_t remainderRankNum = totalNum % aivNum_;
|
||||
|
||||
startIdx = perCoreNum * coreIdx_;
|
||||
if (coreIdx_ < remainderRankNum) {
|
||||
perCoreNum++;
|
||||
startIdx += coreIdx_;
|
||||
} else {
|
||||
startIdx += remainderRankNum;
|
||||
}
|
||||
endIdx = startIdx + perCoreNum;
|
||||
}
|
||||
|
||||
__gm__ HcclOpResParam *epWinContext_{nullptr};
|
||||
__gm__ HcclOpResParam *tpWinContext_{nullptr};
|
||||
uint32_t axisBS_{0};
|
||||
uint32_t axisH_{0};
|
||||
uint32_t axisK_{0};
|
||||
uint32_t aivNum_{0};
|
||||
uint32_t epWorldSize_{0};
|
||||
uint32_t epRankId_{0};
|
||||
uint32_t coreIdx_{0};
|
||||
uint32_t moeExpertNum_{0};
|
||||
uint32_t moeExpertPerRankNum_{0};
|
||||
uint32_t magic_{0};
|
||||
uint64_t winDataSizeOffset_{0};
|
||||
uint32_t selfSendCnt_{0};
|
||||
uint32_t hRecvXTypeLen_{0};
|
||||
uint32_t h32AlignFloatLen_{0};
|
||||
uint32_t h256AlignFloatLen_{0};
|
||||
uint32_t h32AlignRecvXLen_{0};
|
||||
uint32_t h512AlignRecvXLen_{0};
|
||||
|
||||
TPipe *tpipe_{nullptr};
|
||||
TQue<QuePosition::VECIN, 1> weightedSumQueue_;
|
||||
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> localCopyQueue_;
|
||||
TBuf<> stateBuf_;
|
||||
TBuf<> topkWeightsBuf_;
|
||||
TBuf<> tokenFloatBuf_;
|
||||
TBuf<> sumFloatBuf_;
|
||||
TBuf<> weightedMulBuf_;
|
||||
TBuf<> srcInfoBuf_;
|
||||
TBuf<> xOutBuf_;
|
||||
TBuf<> tempStateBuf_;
|
||||
|
||||
GlobalTensor<RecvXType> recvXGM_;
|
||||
GlobalTensor<SrcInfoType> tokenSrcInfoGM_;
|
||||
GlobalTensor<SrcInfoType> epRecvCountGM_;
|
||||
GlobalTensor<float> topkWeightsGM_;
|
||||
GlobalTensor<XType> xOutGlobal_;
|
||||
GM_ADDR localRankGM_;
|
||||
GM_ADDR workspaceGM_;
|
||||
};
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitMagic()
|
||||
{
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
epWinContext_ = (__gm__ HcclOpResParam*)contextGM0;
|
||||
|
||||
GlobalTensor<int32_t> selfMagicTensor;
|
||||
selfMagicTensor.SetGlobalBuffer((__gm__ int32_t*)((GM_ADDR)epWinContext_->localWindowsExp + MAGIC_WIN_OFFSET +
|
||||
coreIdx_ * WIN_512_ALIGN));
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfMagicTensor);
|
||||
magic_ = selfMagicTensor(0);
|
||||
selfMagicTensor(0) = ((magic_ == 0) ? 1 : 0);
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfMagicTensor);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(
|
||||
GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR XOut)
|
||||
{
|
||||
recvXGM_.SetGlobalBuffer((__gm__ RecvXType*)recvX);
|
||||
tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType*)tokenSrcInfo);
|
||||
epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t*)epRecvCount);
|
||||
topkWeightsGM_.SetGlobalBuffer((__gm__ float*)topkWeights);
|
||||
xOutGlobal_.SetGlobalBuffer((__gm__ XType*)XOut);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitTilingData(const MoeCombineNormalTilingData *tilingData)
|
||||
{
|
||||
axisBS_ = tilingData->moeCombineNormalInfo.bs;
|
||||
axisH_ = tilingData->moeCombineNormalInfo.h;
|
||||
axisK_ = tilingData->moeCombineNormalInfo.k;
|
||||
aivNum_ = tilingData->moeCombineNormalInfo.aivNum;
|
||||
moeExpertNum_ = tilingData->moeCombineNormalInfo.moeExpertNum;
|
||||
moeExpertPerRankNum_ = tilingData->moeCombineNormalInfo.moeExpertPerRankNum;
|
||||
epWorldSize_ = tilingData->moeCombineNormalInfo.epWorldSize;
|
||||
epRankId_ = tilingData->moeCombineNormalInfo.epRankId;
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitBuffLen()
|
||||
{
|
||||
uint32_t hFloatSize = axisH_ * static_cast<uint32_t>(sizeof(float));
|
||||
h32AlignFloatLen_ = Ceil(hFloatSize, UB_32_ALIGN) * UB_32_ALIGN;
|
||||
h256AlignFloatLen_ = Ceil(hFloatSize, MUL_256_ALIGN) * MUL_256_ALIGN;
|
||||
hRecvXTypeLen_ = axisH_ * sizeof(RecvXType);
|
||||
h32AlignRecvXLen_ = Ceil(hRecvXTypeLen_, UB_32_ALIGN) * UB_32_ALIGN;
|
||||
h512AlignRecvXLen_ = Ceil(hRecvXTypeLen_, WIN_512_ALIGN) * WIN_512_ALIGN;
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo,
|
||||
GM_ADDR epRecvCount, GM_ADDR topkWeights,
|
||||
GM_ADDR tpRecvCount, GM_ADDR XOut,
|
||||
GM_ADDR workspaceGM, TPipe *pipe,
|
||||
const MoeCombineNormalTilingData *tilingData)
|
||||
{
|
||||
workspaceGM_ = workspaceGM;
|
||||
tpipe_ = pipe;
|
||||
coreIdx_ = GetBlockIdx();
|
||||
|
||||
InitMagic();
|
||||
InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, XOut);
|
||||
InitTilingData(tilingData);
|
||||
InitBuffLen();
|
||||
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
winDataSizeOffset_ = static_cast<uint64_t>(magic_) * (tilingData->moeCombineNormalInfo.totalWinSize / 2UL);
|
||||
localRankGM_ = GetBufferAddrByRankId(epRankId_);
|
||||
DataCacheCleanAndInvalid<SrcInfoType, CacheLine::SINGLE_CACHE_LINE,
|
||||
DcciDst::CACHELINE_OUT>(epRecvCountGM_[moeExpertNum_ - 1]);
|
||||
selfSendCnt_ = epRecvCountGM_(moeExpertNum_ - 1);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShareAndSetStatus()
|
||||
{
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
uint32_t perBlockSendNum = 0, startTokenId = 0, endTokenId = 0;
|
||||
SplitCoreCal(selfSendCnt_, perBlockSendNum, startTokenId, endTokenId);
|
||||
if (perBlockSendNum == 0U) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t blockLen = static_cast<uint32_t>(perBlockSendNum * TOKEN_SRC_INFO_LEN * sizeof(uint32_t));
|
||||
tpipe_->Reset();
|
||||
tpipe_->InitBuffer(stateBuf_, UB_32_ALIGN);
|
||||
tpipe_->InitBuffer(localCopyQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_);
|
||||
tpipe_->InitBuffer(srcInfoBuf_, blockLen);
|
||||
LocalTensor<uint32_t> statusTensor = stateBuf_.AllocTensor<uint32_t>();
|
||||
Duplicate<uint32_t>(statusTensor, 0x3F800000, FLOAT_NUM_PER_ALIGN);
|
||||
|
||||
LocalTensor<SrcInfoType> srcInfoLocal = srcInfoBuf_.Get<SrcInfoType>();
|
||||
const DataCopyExtParams dataCopyParams{1U, blockLen, 0U, 0U, 0U};
|
||||
const DataCopyPadExtParams<SrcInfoType> padParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(srcInfoLocal, tokenSrcInfoGM_[startTokenId * TOKEN_SRC_INFO_LEN], dataCopyParams, padParams);
|
||||
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) {
|
||||
uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN;
|
||||
uint32_t srcRankId = static_cast<uint32_t>(srcInfoLocal(index + RANK_ID_OFFSET_IN_SRC_INFO));
|
||||
uint32_t srcTokenId = static_cast<uint32_t>(srcInfoLocal(index + TOKEN_IDX_OFFSET_IN_SRC_INFO));
|
||||
uint32_t srcTopkId = static_cast<uint32_t>(srcInfoLocal(index + TOPK_IDX_OFFSET_IN_SRC_INFO));
|
||||
CopyBufferToShare(srcRankId, srcTokenId, srcTopkId, tokenIndex);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId);
|
||||
}
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId,
|
||||
uint32_t srcTopkId, uint32_t tkIndex)
|
||||
{
|
||||
uint32_t tokenOffset = tkIndex * axisH_;
|
||||
GM_ADDR dstGM = GetBufferAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_;
|
||||
GlobalTensor<XType> dstWindow;
|
||||
dstWindow.SetGlobalBuffer((__gm__ XType*)dstGM);
|
||||
DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
|
||||
LocalTensor<RecvXType> localCopyTensor;
|
||||
localCopyTensor = localCopyQueue_.AllocTensor<RecvXType>();
|
||||
DataCopyPad(localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams);
|
||||
localCopyQueue_.EnQue(localCopyTensor);
|
||||
localCopyTensor = localCopyQueue_.DeQue<RecvXType>();
|
||||
DataCopyPad(dstWindow, localCopyTensor, xOutCopyParams);
|
||||
localCopyQueue_.FreeTensor<RecvXType>(localCopyTensor);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId,
|
||||
uint32_t srcTopkId)
|
||||
{
|
||||
LocalTensor<uint32_t> statusTensor = stateBuf_.AllocTensor<uint32_t>();
|
||||
GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN;
|
||||
GlobalTensor<uint32_t> stateGMTensor;
|
||||
stateGMTensor.SetGlobalBuffer((__gm__ uint32_t*)stateGM);
|
||||
DataCopy<uint32_t>(stateGMTensor, statusTensor, FLOAT_NUM_PER_ALIGN);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::WaitBuffCopy(uint32_t tokenIndex)
|
||||
{
|
||||
uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN;
|
||||
GM_ADDR stateGM = GetStateAddrByRankId(epRankId_) + tokenIndex * axisK_ * UB_32_ALIGN; // Calculate address offset
|
||||
GlobalTensor<float> stateGMTensor;
|
||||
stateGMTensor.SetGlobalBuffer((__gm__ float*)stateGM);
|
||||
float current = (float)0.0;
|
||||
float target = (float)1.0 * axisK_ * FLOAT_NUM_PER_ALIGN;
|
||||
SumParams sumPerKParams{1, calCount, calCount};
|
||||
LocalTensor<float> stateTensorLocal = stateBuf_.Get<float>();
|
||||
LocalTensor<float> tempStateTensorLocal = tempStateBuf_.Get<float>();
|
||||
while (current != target) {
|
||||
SyncFunc<AscendC::HardEvent::S_MTE2>();
|
||||
DataCopy<float>(stateTensorLocal, stateGMTensor, calCount);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_V>();
|
||||
Sum(tempStateTensorLocal, stateTensorLocal, sumPerKParams);
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
current = tempStateTensorLocal(0);
|
||||
}
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Duplicate<float>(tempStateTensorLocal, (float)0.0, calCount);
|
||||
SyncFunc<AscendC::HardEvent::V_MTE3>();
|
||||
DataCopy<float>(stateGMTensor, tempStateTensorLocal, calCount);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferAndWeightedSum(uint32_t tokenIndex,
|
||||
uint32_t startTokenIndex)
|
||||
{
|
||||
LocalTensor<float> tokenFloatLocal = tokenFloatBuf_.Get<float>();
|
||||
LocalTensor<float> weightedMulBufLocal = weightedMulBuf_.Get<float>();
|
||||
LocalTensor<float> sumFloatBufLocal = sumFloatBuf_.Get<float>();
|
||||
LocalTensor<float> topkWeightsLocal = topkWeightsBuf_.Get<float>();
|
||||
LocalTensor<uint32_t> stateTensorLocal = stateBuf_.Get<uint32_t>();
|
||||
Duplicate(sumFloatBufLocal, static_cast<float>(0), axisH_);
|
||||
const DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
|
||||
|
||||
for (uint32_t topkId = 0U; topkId < axisK_; topkId++) {
|
||||
float scale = topkWeightsLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId);
|
||||
GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_;
|
||||
GlobalTensor<XType> localTokenTensor;
|
||||
localTokenTensor.SetGlobalBuffer((__gm__ XType*)localTokenAddr);
|
||||
|
||||
LocalTensor<XType> tmpToken = weightedSumQueue_.AllocTensor<XType>();
|
||||
const DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(tmpToken, localTokenTensor, xOutCopyParams, copyPadExtParams);
|
||||
weightedSumQueue_.EnQue(tmpToken);
|
||||
tmpToken = weightedSumQueue_.DeQue<XType>();
|
||||
Cast(tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_);
|
||||
PipeBarrier<PIPE_V>();
|
||||
AscendC::Muls(weightedMulBufLocal, tokenFloatLocal, scale, axisH_);
|
||||
PipeBarrier<PIPE_V>();
|
||||
AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, weightedMulBufLocal, axisH_);
|
||||
weightedSumQueue_.FreeTensor<XType>(tmpToken);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<XType> xOutLocal = xOutBuf_.Get<XType>();
|
||||
Cast(xOutLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, axisH_);
|
||||
SyncFunc<AscendC::HardEvent::V_MTE3>();
|
||||
DataCopyPad(xOutGlobal_[tokenIndex * axisH_], xOutLocal, xOutCopyParams);
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferFromRemote()
|
||||
{
|
||||
if (axisBS_ == 0U) {
|
||||
return;
|
||||
}
|
||||
uint32_t tokenPerBlock = 0U, startTokenIndex = 0U, endTokenIndex = 0U;
|
||||
SplitCoreCal(axisBS_, tokenPerBlock, startTokenIndex, endTokenIndex);
|
||||
|
||||
if (tokenPerBlock == 0U) {
|
||||
return;
|
||||
}
|
||||
|
||||
tpipe_->Reset();
|
||||
tpipe_->InitBuffer(xOutBuf_, h32AlignRecvXLen_);
|
||||
tpipe_->InitBuffer(tokenFloatBuf_, h32AlignFloatLen_);
|
||||
tpipe_->InitBuffer(weightedMulBuf_, h256AlignFloatLen_);
|
||||
tpipe_->InitBuffer(sumFloatBuf_, h32AlignFloatLen_);
|
||||
tpipe_->InitBuffer(weightedSumQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_);
|
||||
tpipe_->InitBuffer(stateBuf_, (axisK_) * UB_32_ALIGN);
|
||||
tpipe_->InitBuffer(tempStateBuf_, (axisK_) * UB_32_ALIGN);
|
||||
tpipe_->InitBuffer(topkWeightsBuf_, tokenPerBlock * axisK_ * sizeof(float));
|
||||
|
||||
LocalTensor<float> topkWeightsLocal = topkWeightsBuf_.Get<float>();
|
||||
const DataCopyExtParams bskParams{1U, static_cast<uint32_t>(tokenPerBlock * axisK_ * sizeof(float)), 0U, 0U, 0U};
|
||||
const DataCopyPadExtParams<float> copyPadFloatParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(topkWeightsLocal, topkWeightsGM_[startTokenIndex * axisK_], bskParams, copyPadFloatParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
for (uint32_t tokenIndex = startTokenIndex; tokenIndex < endTokenIndex; tokenIndex++) {
|
||||
WaitBuffCopy(tokenIndex);
|
||||
SyncFunc<AscendC::HardEvent::MTE3_V>(); // Sync with result datacopy on same tensor
|
||||
ReadBufferAndWeightedSum(tokenIndex, startTokenIndex);
|
||||
}
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::Process()
|
||||
{
|
||||
if ASCEND_IS_AIV { // All AIV processing
|
||||
CopyBufferToShareAndSetStatus();
|
||||
ReadBufferFromRemote();
|
||||
}
|
||||
}
|
||||
|
||||
} // MoeCombineNormalImpl
|
||||
#endif // MOE_COMBINE_IMPL_H
|
||||
@@ -0,0 +1,33 @@
|
||||
#ifndef MOE_COMBINE_NORMAL_TILING_H
|
||||
#define MOE_COMBINE_NORMAL_TILING_H
|
||||
|
||||
#include <cstdint>
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
// a3
|
||||
struct MoeCombineNormalInfo {
|
||||
uint32_t epWorldSize;
|
||||
uint32_t tpWorldSize;
|
||||
uint32_t epRankId;
|
||||
uint32_t tpRankId;
|
||||
uint32_t expertShardType;
|
||||
uint32_t moeExpertNum;
|
||||
uint32_t moeExpertPerRankNum;
|
||||
uint32_t globalBs;
|
||||
uint32_t bs;
|
||||
uint32_t k;
|
||||
uint32_t h;
|
||||
uint32_t aivNum;
|
||||
uint64_t totalUbSize;
|
||||
uint64_t totalWinSize;
|
||||
float armAvgFactor;
|
||||
float epsilon;
|
||||
};
|
||||
struct MoeCombineNormalTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
Mc2CcTiling mc2CcTiling2;
|
||||
MoeCombineNormalInfo moeCombineNormalInfo;
|
||||
};
|
||||
|
||||
#endif //MOE_COMBINE_NORMAL_TILING_H
|
||||
49
csrc/moe_dispatch_normal/op_host/CMakeLists.txt
Normal file
49
csrc/moe_dispatch_normal/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME MoeDispatchNormal
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
moe_dispatch_normal_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_dispatch_normal.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_moe_dispatch_normal.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern aclnnStatus aclnnInnerMoeDispatchNormalGetWorkspaceSize(
|
||||
const aclTensor *x,
|
||||
const aclTensor *topkIdx,
|
||||
const aclTensor *sendOffset,
|
||||
const aclTensor *sendTokenIdx,
|
||||
const aclTensor *recvOffset,
|
||||
const aclTensor *recvCount,
|
||||
char *groupEp,
|
||||
int64_t epWorldSize,
|
||||
int64_t epRankId,
|
||||
char *groupTpOptional,
|
||||
int64_t tpWorldSize,
|
||||
int64_t tpRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t quantMode,
|
||||
int64_t globalBs,
|
||||
const aclTensor *recvX,
|
||||
const aclTensor *recvXScales,
|
||||
const aclTensor *assistInfoForCombine,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerMoeDispatchNormal(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x, const aclTensor *topkIdx,
|
||||
const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset, const aclTensor *recvCount,
|
||||
char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, int64_t tpWorldSize, int64_t tpRankId,
|
||||
int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, const aclTensor *recvX,
|
||||
const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerMoeDispatchNormalGetWorkspaceSize(x,
|
||||
topkIdx,
|
||||
sendOffset,
|
||||
sendTokenIdx,
|
||||
recvOffset,
|
||||
recvCount,
|
||||
groupEp,
|
||||
epWorldSize,
|
||||
epRankId,
|
||||
groupTpOptional,
|
||||
tpWorldSize,
|
||||
tpRankId,
|
||||
moeExpertNum,
|
||||
quantMode,
|
||||
globalBs,
|
||||
recvX,
|
||||
recvXScales,
|
||||
assistInfoForCombine,
|
||||
workspaceSize,
|
||||
executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnMoeDispatchNormal(
|
||||
void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerMoeDispatchNormal(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
24
csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h
Normal file
24
csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#ifndef ACLNN_MOE_DISPATCH_NORMAL_H_
|
||||
#define ACLNN_MOE_DISPATCH_NORMAL_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x,
|
||||
const aclTensor *topkIdx, const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset,
|
||||
const aclTensor *recvCount, char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional,
|
||||
int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t quantMode, int64_t globalBs,
|
||||
const aclTensor *recvX, const aclTensor *recvXScales, const aclTensor *assistInfoForCombine,
|
||||
uint64_t *workspaceSize, aclOpExecutor **executor);
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormal(
|
||||
void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
92
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp
Normal file
92
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp
Normal file
@@ -0,0 +1,92 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class MoeDispatchNormal : public OpDef {
|
||||
public:
|
||||
explicit MoeDispatchNormal(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("topk_idx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Input("send_offset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("send_tokenIdx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("recv_offset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("recv_count")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Output("recv_x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Output("x_scales")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Output("assist_info_for_combine")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("group_ep").AttrType(REQUIRED).String();
|
||||
this->Attr("ep_world_size").AttrType(REQUIRED).Int();
|
||||
this->Attr("ep_rank_id").AttrType(REQUIRED).Int();
|
||||
this->Attr("group_tp").AttrType(OPTIONAL).String("");
|
||||
this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("moe_expert_num").AttrType(REQUIRED).Int();
|
||||
this->Attr("quant_mode").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("global_bs").AttrType(OPTIONAL).Int(0);
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_true")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->MC2().HcclGroup({"group_ep", "group_tp"});
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(MoeDispatchNormal);
|
||||
|
||||
} // namespace ops
|
||||
635
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp
Normal file
635
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp
Normal file
@@ -0,0 +1,635 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/moe_dispatch_normal_tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace ge;
|
||||
namespace {
|
||||
class Mc2TilingUtils {
|
||||
public:
|
||||
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
|
||||
static uint64_t GetMaxWindowSize()
|
||||
{
|
||||
uint16_t defaultWindowSize = 200;
|
||||
if (getenv(HCCL_BUFFSIZE) == nullptr) {
|
||||
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
|
||||
} else {
|
||||
try {
|
||||
std::string envStr(getenv(HCCL_BUFFSIZE));
|
||||
defaultWindowSize = std::stoi(envStr);
|
||||
} catch (const std::invalid_argument &ia) {
|
||||
OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
|
||||
} catch (const std::out_of_range &oor) {
|
||||
OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
|
||||
}
|
||||
}
|
||||
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
|
||||
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
|
||||
return maxWindowSize;
|
||||
}
|
||||
};
|
||||
constexpr uint32_t X_INDEX = 0U;
|
||||
constexpr uint32_t EXPERT_IDS_INDEX = 1U;
|
||||
constexpr uint32_t SEND_OFFSET_INDEX = 2U;
|
||||
constexpr uint32_t SEND_TOKENIDX_INDEX = 3U;
|
||||
constexpr uint32_t RECV_OFFSET_INDEX = 4U;
|
||||
constexpr uint32_t RECV_COUNT_INDEX = 5U;
|
||||
|
||||
constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0U;
|
||||
constexpr uint32_t OUTPUT_DYNAMIC_SCALES_INDEX = 1U;
|
||||
constexpr uint32_t OUTPUT_ASSIST_INFO_INDEX = 2U;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
|
||||
constexpr uint32_t ATTR_GROUP_TP_INDEX = 3;
|
||||
constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4;
|
||||
constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5;
|
||||
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6;
|
||||
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 7;
|
||||
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 8;
|
||||
|
||||
constexpr uint32_t TWO_DIMS = 2;
|
||||
constexpr uint32_t ONE_DIM = 1;
|
||||
constexpr uint32_t DYNAMIC_SCALE_DIM_NUM = 1;
|
||||
constexpr uint64_t INIT_TILINGKEY = 10000;
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
|
||||
constexpr uint32_t NO_SCALES = 0;
|
||||
constexpr uint32_t DYNAMIC_SCALES = 2;
|
||||
constexpr uint32_t OP_TYPE_ALL_GATHER = 6;
|
||||
|
||||
constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
|
||||
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
|
||||
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
|
||||
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
|
||||
constexpr int64_t BS_UPPER_BOUND = 8000; // Maximum bs
|
||||
|
||||
constexpr uint32_t TILINGKEY_TP_WORLD_SIZE = 100;
|
||||
constexpr uint32_t TP_WORLD_SIZE_TWO = 2;
|
||||
constexpr int64_t MOE_EXPERT_MAX_NUM = 512;
|
||||
constexpr int64_t K_MAX = 16;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t WORKSPACE_ELEMENT_OFFSET = 512;
|
||||
constexpr int64_t H_MIN = 1024;
|
||||
constexpr int64_t H_MAX = 7168;
|
||||
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
|
||||
constexpr uint64_t TRIPLE = 3;
|
||||
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
|
||||
constexpr uint64_t SCALE_EXPAND_IDX_BUFFER = 44UL; // scale32B + 3*4expandIdx
|
||||
constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
|
||||
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
|
||||
constexpr uint64_t UB_ALIGN = 32UL;
|
||||
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static void PrintTilingDataInfo(const char *nodeName, MoeDispatchNormalTilingData &tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeDispatchNormalInfo.epWorldSize);
|
||||
OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeDispatchNormalInfo.tpWorldSize);
|
||||
OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeDispatchNormalInfo.epRankId);
|
||||
OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeDispatchNormalInfo.tpRankId);
|
||||
OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeDispatchNormalInfo.moeExpertNum);
|
||||
OPS_LOG_D(nodeName, "quantMode is %u.", tilingData.moeDispatchNormalInfo.quantMode);
|
||||
OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeDispatchNormalInfo.globalBs);
|
||||
OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeDispatchNormalInfo.bs);
|
||||
OPS_LOG_D(nodeName, "k is %u.", tilingData.moeDispatchNormalInfo.k);
|
||||
OPS_LOG_D(nodeName, "h is %u.", tilingData.moeDispatchNormalInfo.h);
|
||||
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeDispatchNormalInfo.aivNum);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeDispatchNormalInfo.totalUbSize);
|
||||
OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeDispatchNormalInfo.totalWinSize);
|
||||
}
|
||||
|
||||
static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "xShape is null."), return false);
|
||||
OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"xShape dims must be 2, but current dim num is %lu.",
|
||||
xStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_LOG_D(nodeName, "x dim0 = %ld", xDim0);
|
||||
OPS_LOG_D(nodeName, "x dim1 = %ld", xDim1);
|
||||
|
||||
const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIdShape is null."), return false);
|
||||
OPS_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expertIdShape dims must be 2, but current dim num is %lu.",
|
||||
expertIdStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXStorageShape == nullptr, OPS_LOG_E(nodeName, "expandXShape is null."), return false);
|
||||
OPS_CHECK(expandXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expandXShape dims must be 2, but current dim num is %lu.",
|
||||
expandXStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "expandX dim0 = %ld", expandXStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "expandX dim1 = %ld", expandXStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(
|
||||
dynamicScalesStorageShape == nullptr, OPS_LOG_E(nodeName, "dynamicScalesShape is null."), return false);
|
||||
OPS_CHECK(dynamicScalesStorageShape->GetStorageShape().GetDimNum() != DYNAMIC_SCALE_DIM_NUM,
|
||||
OPS_LOG_E(nodeName,
|
||||
"dynamicScalesShape dims must be %u, but current dim num is %lu.",
|
||||
DYNAMIC_SCALE_DIM_NUM,
|
||||
dynamicScalesStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "dynamicScales dim0 = %ld", dynamicScalesStorageShape->GetStorageShape().GetDim(0));
|
||||
}
|
||||
|
||||
const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "assistInfoShape is null."), return false);
|
||||
OPS_CHECK(assistInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
|
||||
OPS_LOG_E(nodeName,
|
||||
"assistInfoShape dims must be 1, but current dim num is %lu.",
|
||||
assistInfoStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "assistInfoForCombine dim0 = %ld", assistInfoStorageShape->GetStorageShape().GetDim(0));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK((xDesc->GetDataType() != ge::DT_BF16) && (xDesc->GetDataType() != ge::DT_FLOAT16),
|
||||
OPS_LOG_E(nodeName, "x dataType is invalid, dataType should be bf16 or float16, but is ."),
|
||||
return false);
|
||||
|
||||
auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false);
|
||||
OPS_CHECK(expertIdDesc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(nodeName, "expertId dataType is invalid, dataType should be int32, but is ."),
|
||||
return false);
|
||||
|
||||
auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false);
|
||||
if (quantMode != NO_SCALES) {
|
||||
OPS_CHECK(expandXDesc->GetDataType() != ge::DT_INT8,
|
||||
OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be int8, but is."),
|
||||
return false);
|
||||
} else {
|
||||
OPS_CHECK(expandXDesc->GetDataType() != xDesc->GetDataType(),
|
||||
OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be equal to x dataType , but is."),
|
||||
return false);
|
||||
}
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false);
|
||||
OPS_CHECK(dynamicScalesDesc->GetDataType() != ge::DT_FLOAT,
|
||||
OPS_LOG_E(nodeName, "dynamicScales dataType is invalid, dataType should be float, but is ."),
|
||||
return false);
|
||||
}
|
||||
|
||||
auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false);
|
||||
OPS_CHECK(assistInfoDesc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(nodeName, "assistInfoForCombine dataType is invalid, dataType should be int32, but is ."),
|
||||
return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "x format is invalid."),
|
||||
return false);
|
||||
|
||||
auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(expertIdDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "expertId format is invalid."),
|
||||
return false);
|
||||
|
||||
auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(expandXDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "expandX format is invalid."),
|
||||
return false);
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(dynamicScalesDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "dynamicScales format is invalid."),
|
||||
return false);
|
||||
}
|
||||
|
||||
auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(assistInfoDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "assistInfoForCombine format is invalid."),
|
||||
return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
|
||||
MoeDispatchNormalTilingData &tilingData, std::string &groupEp, std::string &groupTp)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
|
||||
auto groupTpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_TP_INDEX));
|
||||
auto epWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_WORLD_SIZE_INDEX);
|
||||
auto tpWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_WORLD_SIZE_INDEX);
|
||||
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
|
||||
auto tpRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_RANK_ID_INDEX);
|
||||
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
|
||||
auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
|
||||
|
||||
// Check for null
|
||||
OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "groupEpPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(quantModePtr == nullptr, OPS_LOG_E(nodeName, "quantModePtr is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if it meets uint32_t and other constraints
|
||||
int64_t moeExpertNum = *moeExpertNumPtr;
|
||||
int64_t epWorldSize = *epWorldSizePtr;
|
||||
OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.",
|
||||
MIN_EP_WORLD_SIZE,
|
||||
MAX_EP_WORLD_SIZE,
|
||||
epWorldSize),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.",
|
||||
MAX_TP_WORLD_SIZE,
|
||||
*tpWorldSizePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize),
|
||||
OPS_LOG_E(
|
||||
nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.", epWorldSize, *epRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
if (*tpWorldSizePtr > 1) {
|
||||
OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.",
|
||||
*tpWorldSizePtr,
|
||||
*tpRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "groupTpPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
groupTp = std::string(groupTpPtr);
|
||||
} else {
|
||||
OPS_CHECK(*tpRankIdPtr != 0,
|
||||
OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM),
|
||||
OPS_LOG_E(nodeName,
|
||||
"moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.",
|
||||
MOE_EXPERT_MAX_NUM,
|
||||
moeExpertNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(
|
||||
(*quantModePtr < static_cast<int64_t>(NO_SCALES)) || (*quantModePtr > static_cast<int64_t>(DYNAMIC_SCALES)),
|
||||
OPS_LOG_E(nodeName,
|
||||
"quantMode is invalid, only support [0, %u], but got quantMode=%ld.",
|
||||
DYNAMIC_SCALES,
|
||||
*quantModePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
int64_t moePerRankNum = moeExpertNum / epWorldSize;
|
||||
int64_t curDispatchStatusNum = moePerRankNum * epWorldSize;
|
||||
OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM),
|
||||
OPS_LOG_E(nodeName,
|
||||
"The moe experts num must meet the conditions,"
|
||||
" (moeExpertNum / epWorldSize * epWorldSize <= 1280, but cur is %ld.",
|
||||
curDispatchStatusNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
groupEp = std::string(groupEpPtr);
|
||||
tilingData.moeDispatchNormalInfo.epWorldSize = static_cast<uint32_t>(epWorldSize);
|
||||
tilingData.moeDispatchNormalInfo.tpWorldSize = static_cast<uint32_t>(*tpWorldSizePtr);
|
||||
tilingData.moeDispatchNormalInfo.epRankId = static_cast<uint32_t>(*epRankIdPtr);
|
||||
tilingData.moeDispatchNormalInfo.tpRankId = static_cast<uint32_t>(*tpRankIdPtr);
|
||||
tilingData.moeDispatchNormalInfo.moeExpertNum = static_cast<uint32_t>(moeExpertNum);
|
||||
tilingData.moeDispatchNormalInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckAttrs(
|
||||
gert::TilingContext *context, const char *nodeName, MoeDispatchNormalTilingData &tilingData, uint32_t &localMoeExpertNum)
|
||||
{
|
||||
uint32_t epWorldSize = tilingData.moeDispatchNormalInfo.epWorldSize;
|
||||
uint32_t tpWorldSize = tilingData.moeDispatchNormalInfo.tpWorldSize;
|
||||
uint32_t moeExpertNum = tilingData.moeDispatchNormalInfo.moeExpertNum;
|
||||
|
||||
// Validate if moe expert number can be evenly distributed across multiple machines
|
||||
localMoeExpertNum = moeExpertNum / epWorldSize;
|
||||
OPS_CHECK(moeExpertNum % epWorldSize != 0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"moeExpertNum should be divisible by epWorldSize, "
|
||||
"but moeExpertNum=%u, epWorldSize=%u.",
|
||||
moeExpertNum,
|
||||
epWorldSize),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(localMoeExpertNum <= 0,
|
||||
OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate input x dimension 0 and set bs
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
OPS_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0),
|
||||
OPS_LOG_E(
|
||||
nodeName, "xDim0(BS) is invalid. Should be between [1, %ld], but got xDim0=%ld.", BS_UPPER_BOUND, xDim0),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData.moeDispatchNormalInfo.bs = static_cast<uint32_t>(xDim0);
|
||||
|
||||
// Validate globalBS
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
|
||||
OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBsPtr is nullptr."), return ge::GRAPH_FAILED);
|
||||
OPS_LOG_D(nodeName, "MoeDispatchNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n", *globalBsPtr, xDim0, epWorldSize);
|
||||
OPS_CHECK(*globalBsPtr <= 0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"globalBS is invalid, should be positive, but got globalBS=%ld.",
|
||||
*globalBsPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
tilingData.moeDispatchNormalInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
|
||||
MoeDispatchNormalTilingData &tilingData, const uint32_t quantMode, const int64_t localMoeExpertNum)
|
||||
{
|
||||
uint32_t A = 0U;
|
||||
uint32_t globalBs = tilingData.moeDispatchNormalInfo.globalBs;
|
||||
|
||||
// Validate input x dimension 1 and set h, bs already validated
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK((xDim1 < H_MIN) || (xDim1 > H_MAX),
|
||||
OPS_LOG_E(nodeName, "xShape dims1(H) should be in [%ld, %ld], but got %ld.", H_MIN, H_MAX, xDim1),
|
||||
return ge::GRAPH_FAILED); // 32-byte aligned
|
||||
tilingData.moeDispatchNormalInfo.h = static_cast<uint32_t>(xDim1);
|
||||
|
||||
// Validate expert_id dimensions and set k
|
||||
int64_t moeExpertNum = static_cast<int64_t>(tilingData.moeDispatchNormalInfo.moeExpertNum);
|
||||
const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX);
|
||||
const int64_t expertIdsDim0 = expertIdStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t expertIdsDim1 = expertIdStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK(xDim0 != expertIdsDim0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"xShape's dim0 not equal to expertIdShape's dim0, "
|
||||
"xShape's dim0 is %ld, expertIdShape's dim0 is %ld.",
|
||||
xDim0,
|
||||
expertIdsDim0),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((expertIdsDim1 <= 0) || (expertIdsDim1 > K_MAX) || (expertIdsDim1 > moeExpertNum),
|
||||
OPS_LOG_E(nodeName,
|
||||
"expertIdShape's dim1(k) should be in (0, min(%ld, moeExpertNum=%ld)], "
|
||||
"but got expertIdShape's dim1=%ld.",
|
||||
K_MAX,
|
||||
moeExpertNum,
|
||||
expertIdsDim1),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData.moeDispatchNormalInfo.k = static_cast<uint32_t>(expertIdsDim1);
|
||||
|
||||
A = globalBs;
|
||||
|
||||
// Validate expandX dimensions
|
||||
const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX);
|
||||
const int64_t expandXDim0 = expandXStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t expandXDim1 = expandXStorageShape->GetStorageShape().GetDim(1);
|
||||
|
||||
OPS_CHECK(xDim1 != expandXDim1,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expandX's dim1 not equal to xShape's dim1, "
|
||||
"xShape's dim1 is %ld, expandX's dim1 is %ld.",
|
||||
xDim1,
|
||||
expandXDim1),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate dynamicScales dimensions
|
||||
if (quantMode != NO_SCALES) {
|
||||
const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
const int64_t dynamicScalesDim0 = dynamicScalesStorageShape->GetStorageShape().GetDim(0);
|
||||
}
|
||||
|
||||
// Validate assistInfo dimensions
|
||||
const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX);
|
||||
const int64_t assistInfoDim0 = assistInfoStorageShape->GetStorageShape().GetDim(0);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckMoeDispatchNormal(
|
||||
gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
OPS_CHECK(!CheckTensorDim(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params shape is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(!CheckTensorFormat(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params format is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void CalTilingKey(uint64_t &tilingKey, const uint32_t quantMode, const uint32_t tpWorldSize)
|
||||
{
|
||||
tilingKey += static_cast<uint64_t>(quantMode);
|
||||
if (tpWorldSize == TP_WORLD_SIZE_TWO) {
|
||||
tilingKey += static_cast<uint64_t>(TILINGKEY_TP_WORLD_SIZE);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
static void SetHcommCfg(const gert::TilingContext *context, MoeDispatchNormalTilingData *tiling, const std::string groupEp,
|
||||
const std::string groupTp)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "MoeDispatchNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str());
|
||||
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
|
||||
uint32_t opType2 = OP_TYPE_ALL_GATHER;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
std::string algConfigAllGatherStr = "AllGather=level0:ring";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
|
||||
|
||||
mc2CcTilingConfig.SetGroupName(groupTp);
|
||||
mc2CcTilingConfig.SetOpType(opType2);
|
||||
mc2CcTilingConfig.SetAlgConfig(algConfigAllGatherStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2);
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
workSpaces[0] = static_cast<uint64_t>(SYSTEM_NEED_WORKSPACE + WORKSPACE_ELEMENT_OFFSET * aivNum * aivNum);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeDispatchNormalA3TilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
MoeDispatchNormalTilingData *tilingData = context->GetTilingData<MoeDispatchNormalTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
std::string groupEp = "";
|
||||
std::string groupTp = "";
|
||||
uint32_t quantMode = NO_SCALES;
|
||||
uint32_t localMoeExpertNum = 1;
|
||||
OPS_LOG_I(nodeName, "Enter MoeDispatchNormal tiling check func.");
|
||||
|
||||
// Get input parameter attributes
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp, groupTp) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
quantMode = tilingData->moeDispatchNormalInfo.quantMode;
|
||||
|
||||
// Check input/output dim, format, dataType
|
||||
OPS_CHECK(TilingCheckMoeDispatchNormal(context, nodeName, quantMode) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check param failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if attribute values are valid
|
||||
OPS_CHECK(CheckAttrs(context, nodeName, *tilingData, localMoeExpertNum) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Check attr failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
uint32_t epRankId = tilingData->moeDispatchNormalInfo.epRankId;
|
||||
|
||||
// Check shape dimensions and assign h, k
|
||||
OPS_CHECK(
|
||||
CheckTensorShape(context, nodeName, *tilingData, quantMode, static_cast<int64_t>(localMoeExpertNum)) !=
|
||||
ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Check tensor shape failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate win area size
|
||||
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
|
||||
uint64_t h = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.h);
|
||||
uint64_t k = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.k);
|
||||
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.epWorldSize);
|
||||
uint64_t maxBs = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.globalBs) / epWorldSize;
|
||||
|
||||
// Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b)
|
||||
uint64_t tokenActualLen =
|
||||
((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_EXPAND_IDX_BUFFER;
|
||||
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
// Not considering dual stream size
|
||||
uint64_t actualSize = maxBs * k * tokenNeedSizeDispatch * DOUBLE_DATA_BUFFER;
|
||||
OPS_CHECK((actualSize > maxWindowSize),
|
||||
OPS_LOG_E(nodeName,
|
||||
"HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu,"
|
||||
" localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu,"
|
||||
" k = %lu, NEEDED_HCCL_BUFFSIZE(maxBs * k * tokenNeedSizeDispatch) = %luMB,"
|
||||
" HCCL_BUFFSIZE=%luMB.",
|
||||
maxBs,
|
||||
h,
|
||||
epWorldSize,
|
||||
localMoeExpertNum,
|
||||
tokenNeedSizeDispatch,
|
||||
k,
|
||||
actualSize / MB_SIZE + 1UL,
|
||||
maxWindowSize / MB_SIZE),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData->moeDispatchNormalInfo.totalWinSize = maxWindowSize;
|
||||
OPS_LOG_D(nodeName, "windowSize = %lu", maxWindowSize);
|
||||
|
||||
OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, groupEp, groupTp);
|
||||
uint32_t tpWorldSize = tilingData->moeDispatchNormalInfo.tpWorldSize;
|
||||
uint64_t tilingKey = INIT_TILINGKEY;
|
||||
CalTilingKey(tilingKey, quantMode, tpWorldSize);
|
||||
OPS_LOG_D(nodeName, "tilingKey is %lu", tilingKey);
|
||||
context->SetTilingKey(tilingKey);
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
|
||||
context->SetBlockDim(blockDim);
|
||||
context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously
|
||||
tilingData->moeDispatchNormalInfo.totalUbSize = ubSize;
|
||||
tilingData->moeDispatchNormalInfo.aivNum = aivNum;
|
||||
OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeDispatchNormalTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
ge::graphStatus ret = MoeDispatchNormalA3TilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct MoeDispatchNormalCompileInfo {};
|
||||
ge::graphStatus TilingParseForMoeDispatchNormal(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(MoeDispatchNormal)
|
||||
.Tiling(MoeDispatchNormalTilingFunc)
|
||||
.TilingParse<MoeDispatchNormalCompileInfo>(TilingParseForMoeDispatchNormal);
|
||||
} // namespace optiling
|
||||
56
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp
Normal file
56
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "moe_dispatch_normal_tiling.h"
|
||||
#include "moe_dispatch_normal.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace MoeDispatchNormalImpl;
|
||||
|
||||
#define TILINGKEY_NO_QUANT 10000
|
||||
#define TILINGKEY_QUANT 10002
|
||||
|
||||
extern "C" __global__ __aicore__ void moe_dispatch_normal(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset,
|
||||
GM_ADDR send_token_idx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut,
|
||||
GM_ADDR assist_info_for_combine, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(MoeDispatchNormalTilingData);
|
||||
TPipe pipe;
|
||||
#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
|
||||
if (TILING_KEY_IS(TILINGKEY_NO_QUANT)) {
|
||||
GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM);
|
||||
MoeDispatchNormal<DTYPE_X, DTYPE_RECV_X, false, false, false> op;
|
||||
op.Init(x,
|
||||
expertIds,
|
||||
send_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
expandXOut,
|
||||
dynamicScalesOut,
|
||||
assist_info_for_combine,
|
||||
workspaceGM,
|
||||
&pipe,
|
||||
&tilingData);
|
||||
op.Process();
|
||||
return;
|
||||
}
|
||||
#elif (ORIG_DTYPE_RECV_X == DT_INT8)
|
||||
if (TILING_KEY_IS(TILINGKEY_QUANT)) {
|
||||
GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM);
|
||||
MoeDispatchNormal<DTYPE_X, DTYPE_RECV_X, true, false, false> op;
|
||||
op.Init(x,
|
||||
expertIds,
|
||||
send_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
expandXOut,
|
||||
dynamicScalesOut,
|
||||
assist_info_for_combine,
|
||||
workspaceGM,
|
||||
&pipe,
|
||||
&tilingData);
|
||||
op.Process();
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
540
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h
Normal file
540
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h
Normal file
@@ -0,0 +1,540 @@
|
||||
#ifndef MOE_DISPATCH_NORMAL_H
|
||||
#define MOE_DISPATCH_NORMAL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_dispatch_normal_tiling.h"
|
||||
|
||||
namespace MoeDispatchNormalImpl {
|
||||
constexpr uint8_t BUFFER_NUM = 2;
|
||||
constexpr uint32_t STATE_OFFSET = 32U;
|
||||
constexpr uint32_t UB_ALIGN = 32U;
|
||||
constexpr uint8_t COMM_NUM = 2;
|
||||
constexpr uint8_t COMM_EP_IDX = 0;
|
||||
constexpr uint8_t COMM_TP_IDX = 1;
|
||||
|
||||
constexpr uint64_t WIN_STATE_OFFSET = 500UL * 1024UL;
|
||||
constexpr uint64_t STATE_WIN_OFFSET = 950UL * 1024UL;
|
||||
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
|
||||
constexpr uint32_t EXPAND_IDX_INFO = 3U;
|
||||
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL;
|
||||
|
||||
template <AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
#define CamTypeClass \
|
||||
typename XType, typename ExpandXOutType, bool DynamicQuant, bool IsSmoothScaleExist, bool IsShareExpertRank
|
||||
|
||||
#define CamTypeFunc XType, ExpandXOutType, DynamicQuant, IsSmoothScaleExist, IsShareExpertRank
|
||||
|
||||
using namespace AscendC;
|
||||
template <CamTypeClass>
|
||||
class MoeDispatchNormal {
|
||||
public:
|
||||
__aicore__ inline MoeDispatchNormal(){};
|
||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx,
|
||||
GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
|
||||
GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData);
|
||||
__aicore__ inline void Process();
|
||||
|
||||
private:
|
||||
__aicore__ inline void InputToShare();
|
||||
__aicore__ inline void SetStatus();
|
||||
__aicore__ inline void WaitStatus();
|
||||
__aicore__ inline void ShareToOutput();
|
||||
__aicore__ inline void UpdateOutput();
|
||||
__aicore__ inline void FillTriple(LocalTensor<ExpandXOutType> &xOutTensor, uint32_t tokenIndex, uint32_t k);
|
||||
__aicore__ inline void QuantInit();
|
||||
__aicore__ inline void ReduceMaxInplace(const LocalTensor<float> &srcLocal, uint32_t count);
|
||||
__aicore__ inline void QuantProcess();
|
||||
__aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId)
|
||||
{
|
||||
uint32_t curRankId = ((ctxIdx == COMM_EP_IDX) ? epRankId : tpRankId);
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset + COMBINE_STATE_WIN_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) +
|
||||
winDataSizeOffset + COMBINE_STATE_WIN_OFFSET;
|
||||
}
|
||||
|
||||
__aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId)
|
||||
{
|
||||
uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId : tpRankId;
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState * WIN_STATE_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))
|
||||
->windowsExp) +
|
||||
dataState * WIN_STATE_OFFSET;
|
||||
}
|
||||
|
||||
TPipe *tpipe_{nullptr};
|
||||
GlobalTensor<XType> xGT;
|
||||
GlobalTensor<int32_t> expertIdsGT;
|
||||
GlobalTensor<int32_t> sendOffsetGT;
|
||||
GlobalTensor<int32_t> sendTokenIdxGT;
|
||||
GlobalTensor<int32_t> recvOffsetGT;
|
||||
GlobalTensor<int32_t> recvCountGT;
|
||||
GlobalTensor<float> dynamicScalesOutGT;
|
||||
GlobalTensor<int32_t> expandIdxOutGT;
|
||||
GlobalTensor<ExpandXOutType> dstGT;
|
||||
GlobalTensor<int32_t> dstStatusGT;
|
||||
|
||||
LocalTensor<XType> xInTensor;
|
||||
LocalTensor<ExpandXOutType> xOutTensor;
|
||||
LocalTensor<ExpandXOutType> xTmpTensor;
|
||||
LocalTensor<int32_t> expertIdsTensor;
|
||||
LocalTensor<int32_t> sendOffsetTensor;
|
||||
LocalTensor<int32_t> sendTokenIdxTensor;
|
||||
LocalTensor<int32_t> recvOffsetTensor;
|
||||
LocalTensor<int32_t> recvCountTensor;
|
||||
LocalTensor<int32_t> statusTensor;
|
||||
|
||||
TBuf<> expertIdsBuf;
|
||||
TBuf<> sendOffsetBuf;
|
||||
TBuf<> sendTokenIdxBuf;
|
||||
TBuf<> recvOffsetBuf;
|
||||
TBuf<> recvCountBuf;
|
||||
TBuf<> statusBuf;
|
||||
TBuf<> waitStatusBuf;
|
||||
TBuf<> gatherMaskOutBuf;
|
||||
TBuf<> scalarBuf;
|
||||
TBuf<> tokenCastFloatBuf;
|
||||
TBuf<> tokenAbsFloatBuf;
|
||||
|
||||
GM_ADDR expandXOutGM;
|
||||
GM_ADDR shareGM;
|
||||
|
||||
uint32_t batchSize{0};
|
||||
uint32_t globalBatchSize{0};
|
||||
uint32_t h{0};
|
||||
uint32_t topK{0};
|
||||
uint32_t blockNum{0};
|
||||
uint32_t blockIdx{0};
|
||||
uint32_t epRankSize{0};
|
||||
uint32_t epRankId{0};
|
||||
uint32_t tpRankSize{0};
|
||||
uint32_t tpRankId{0};
|
||||
uint32_t moeExpertNum{0};
|
||||
uint32_t moeExpertNumPerRank{0};
|
||||
|
||||
uint32_t hUBAlignSize{0};
|
||||
uint32_t hOutGMAlignSize{0};
|
||||
uint32_t hOutUBAlignSize{0};
|
||||
uint32_t hGMAlignCnt{0};
|
||||
uint32_t expandIdxStartIdx{0};
|
||||
uint32_t expertIdsCnt{0};
|
||||
uint32_t stateOffset{0};
|
||||
uint32_t dataState{0};
|
||||
uint32_t winDataSizeOffset{0};
|
||||
|
||||
uint32_t startStatusId;
|
||||
uint32_t endStatusId;
|
||||
uint32_t statusNumPerCore;
|
||||
uint32_t remainStatus;
|
||||
|
||||
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue;
|
||||
TQue<QuePosition::VECIN, 1> xInQueue;
|
||||
TQue<QuePosition::VECOUT, 1> xOutQueue;
|
||||
|
||||
__gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr};
|
||||
|
||||
DataCopyExtParams hCommuCopyOutParams;
|
||||
};
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset,
|
||||
GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut,
|
||||
GM_ADDR expandIdxOut, GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData)
|
||||
{
|
||||
tpipe_ = pipe;
|
||||
blockIdx = GetBlockIdx();
|
||||
|
||||
winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>();
|
||||
|
||||
GlobalTensor<int32_t> selfDataStatusTensor;
|
||||
GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp);
|
||||
selfDataStatusTensor.SetGlobalBuffer(
|
||||
(__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET + blockIdx * WIN_ADDR_ALIGN));
|
||||
|
||||
batchSize = tilingData->moeDispatchNormalInfo.bs;
|
||||
globalBatchSize = tilingData->moeDispatchNormalInfo.globalBs;
|
||||
h = tilingData->moeDispatchNormalInfo.h;
|
||||
topK = tilingData->moeDispatchNormalInfo.k;
|
||||
blockNum = tilingData->moeDispatchNormalInfo.aivNum;
|
||||
epRankSize = tilingData->moeDispatchNormalInfo.epWorldSize;
|
||||
epRankId = tilingData->moeDispatchNormalInfo.epRankId;
|
||||
moeExpertNum = tilingData->moeDispatchNormalInfo.moeExpertNum;
|
||||
moeExpertNumPerRank = moeExpertNum / epRankSize;
|
||||
|
||||
xGT.SetGlobalBuffer((__gm__ XType *)x);
|
||||
expertIdsGT.SetGlobalBuffer((__gm__ int32_t *)expertIds);
|
||||
sendOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(send_offset));
|
||||
sendTokenIdxGT.SetGlobalBuffer((__gm__ int32_t *)(send_tokenIdx));
|
||||
recvOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(recv_offset));
|
||||
recvCountGT.SetGlobalBuffer((__gm__ int32_t *)(recv_count));
|
||||
dynamicScalesOutGT.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
|
||||
expandIdxOutGT.SetGlobalBuffer((__gm__ int32_t *)(expandIdxOut));
|
||||
|
||||
expandXOutGM = expandXOut;
|
||||
|
||||
hUBAlignSize = Ceil(h * sizeof(ExpandXOutType), UB_ALIGN) * UB_ALIGN;
|
||||
uint32_t hScaleSizeAlign = hUBAlignSize + UB_ALIGN;
|
||||
expandIdxStartIdx = hScaleSizeAlign / sizeof(int32_t);
|
||||
|
||||
uint32_t hScaleIdxSize = hScaleSizeAlign + EXPAND_IDX_INFO * sizeof(int32_t);
|
||||
hOutGMAlignSize = Ceil(hScaleIdxSize, WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
hGMAlignCnt = hOutGMAlignSize / sizeof(ExpandXOutType);
|
||||
|
||||
expertIdsCnt = batchSize * topK;
|
||||
statusNumPerCore = moeExpertNum / blockNum;
|
||||
remainStatus = moeExpertNum % blockNum;
|
||||
startStatusId = statusNumPerCore * blockIdx;
|
||||
if (blockIdx < remainStatus) {
|
||||
statusNumPerCore += 1;
|
||||
startStatusId += blockIdx;
|
||||
} else {
|
||||
startStatusId += remainStatus;
|
||||
}
|
||||
endStatusId = startStatusId + statusNumPerCore;
|
||||
stateOffset = STATE_OFFSET;
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfDataStatusTensor);
|
||||
dataState = selfDataStatusTensor(0);
|
||||
if (dataState == 0) {
|
||||
selfDataStatusTensor(0) = 1;
|
||||
} else {
|
||||
selfDataStatusTensor(0) = 0;
|
||||
}
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfDataStatusTensor);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
|
||||
uint64_t hSizeAlignCombine = Ceil(h * sizeof(XType), WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
winDataSizeOffset = dataState * (tilingData->moeDispatchNormalInfo.totalWinSize / 2) +
|
||||
globalBatchSize / epRankSize * topK * hSizeAlignCombine;
|
||||
shareGM = GetWindAddrByRankId(COMM_EP_IDX, epRankId);
|
||||
|
||||
hOutUBAlignSize = Ceil(hScaleIdxSize, UB_ALIGN) * UB_ALIGN;
|
||||
if constexpr (DynamicQuant) {
|
||||
QuantInit();
|
||||
} else {
|
||||
tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 2 * 14K = 28K
|
||||
}
|
||||
|
||||
tpipe_->InitBuffer(sendOffsetBuf, moeExpertNum * sizeof(int32_t)); // 4 * moeNum
|
||||
sendOffsetTensor = sendOffsetBuf.Get<int32_t>();
|
||||
|
||||
hCommuCopyOutParams = {1U, static_cast<uint32_t>(hScaleIdxSize), 0U, 0U, 0U};
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::QuantInit()
|
||||
{
|
||||
uint32_t hAlignSize = Ceil(h * sizeof(XType), UB_ALIGN) * UB_ALIGN;
|
||||
tpipe_->InitBuffer(xInQueue, BUFFER_NUM, hAlignSize); // 14K * 2
|
||||
tpipe_->InitBuffer(xOutQueue, BUFFER_NUM, hOutUBAlignSize); // 7K * 2
|
||||
|
||||
tpipe_->InitBuffer(tokenCastFloatBuf, h * sizeof(float)); // 28K
|
||||
tpipe_->InitBuffer(tokenAbsFloatBuf, h * sizeof(float)); // 28K
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::ReduceMaxInplace(
|
||||
const LocalTensor<float> &srcLocal, uint32_t count)
|
||||
{
|
||||
uint64_t repsFp32 = count >> 6; // 6 is count / elemPerRefFp32
|
||||
uint64_t offsetsFp32 = repsFp32 << 6; // 6 is repsFp32 * elemPerRefFp32
|
||||
uint64_t remsFp32 = count & 0x3f; // 0x3f 63, count % elemPerRefFp32
|
||||
const uint64_t elemPerRefFp32 = 64UL; // 256 bit / sizeof(float)
|
||||
if (likely(repsFp32 > 1)) {
|
||||
// 8 is rep stride
|
||||
Max(srcLocal, srcLocal[elemPerRefFp32], srcLocal, elemPerRefFp32, repsFp32 - 1, {1, 1, 1, 0, 8, 0});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (unlikely(remsFp32 > 0) && unlikely(offsetsFp32 > 0)) {
|
||||
Max(srcLocal, srcLocal[offsetsFp32], srcLocal, remsFp32, 1, {1, 1, 1, 0, 8, 0});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
uint32_t mask = (repsFp32 > 0) ? elemPerRefFp32 : count;
|
||||
// 8 is rep stride
|
||||
WholeReduceMax(srcLocal, srcLocal, mask, 1, 8, 1, 8);
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::QuantProcess()
|
||||
{
|
||||
float dynamicScale = 0.0;
|
||||
LocalTensor<float> floatLocalTemp;
|
||||
floatLocalTemp = tokenCastFloatBuf.Get<float>();
|
||||
|
||||
Cast(floatLocalTemp, xInTensor, RoundMode::CAST_NONE, h);
|
||||
xInQueue.FreeTensor<XType>(xInTensor);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
if constexpr (DynamicQuant) {
|
||||
LocalTensor<float> floatLocalAbsTemp = tokenAbsFloatBuf.Get<float>();
|
||||
|
||||
Abs(floatLocalAbsTemp, floatLocalTemp, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceMaxInplace(floatLocalAbsTemp, h);
|
||||
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
dynamicScale = float(127.0) / (floatLocalAbsTemp.GetValue(0) + 1e-12f);
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Muls(floatLocalTemp, floatLocalTemp, dynamicScale, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
LocalTensor<half> halfLocalTemp = floatLocalTemp.ReinterpretCast<half>();
|
||||
LocalTensor<int32_t> int32LocalTemp = floatLocalTemp.ReinterpretCast<int32_t>();
|
||||
Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
SetDeqScale((half)1.000000e+00f);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, h);
|
||||
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xOutTensor, halfLocalTemp, RoundMode::CAST_TRUNC, h);
|
||||
|
||||
floatLocalTemp = xOutTensor.template ReinterpretCast<float>();
|
||||
floatLocalTemp.SetValue(hUBAlignSize / sizeof(float), float(1.0) / dynamicScale); // int8->float32
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::FillTriple(
|
||||
LocalTensor<ExpandXOutType> &xOutTensor, uint32_t tokenIndex, uint32_t k)
|
||||
{
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
LocalTensor<int32_t> xOutTint32 = xOutTensor.template ReinterpretCast<int32_t>();
|
||||
xOutTint32(expandIdxStartIdx) = epRankId;
|
||||
xOutTint32(expandIdxStartIdx + 1) = tokenIndex;
|
||||
xOutTint32(expandIdxStartIdx + 2) = k;
|
||||
SyncFunc<AscendC::HardEvent::S_MTE3>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::InputToShare()
|
||||
{
|
||||
DataCopyExtParams sendOffsetParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> sendOffsetCopyPadParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(sendOffsetTensor, sendOffsetGT, sendOffsetParams, sendOffsetCopyPadParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
uint32_t startTokenId, endTokenId, sendTokenNum, remainTokenNum;
|
||||
sendTokenNum = expertIdsCnt / blockNum;
|
||||
remainTokenNum = expertIdsCnt % blockNum;
|
||||
startTokenId = sendTokenNum * blockIdx;
|
||||
if (blockIdx < remainTokenNum) {
|
||||
sendTokenNum += 1;
|
||||
startTokenId += blockIdx;
|
||||
} else {
|
||||
startTokenId += remainTokenNum;
|
||||
}
|
||||
endTokenId = startTokenId + sendTokenNum;
|
||||
|
||||
if (startTokenId >= expertIdsCnt) {
|
||||
return;
|
||||
}
|
||||
tpipe_->InitBuffer(expertIdsBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48
|
||||
tpipe_->InitBuffer(sendTokenIdxBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48
|
||||
expertIdsTensor = expertIdsBuf.Get<int32_t>();
|
||||
sendTokenIdxTensor = sendTokenIdxBuf.Get<int32_t>();
|
||||
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams sendTokenIdxParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<XType> tokenCopyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(expertIdsTensor, expertIdsGT[startTokenId], expertIdsCntParams, copyPadExtParams);
|
||||
DataCopyPad(sendTokenIdxTensor, sendTokenIdxGT[startTokenId], sendTokenIdxParams, copyPadExtParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
DataCopyExtParams xCopyParams = {1U, static_cast<uint32_t>(h * sizeof(XType)), 0U, 0U, 0U};
|
||||
for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
|
||||
uint32_t dstExpertId = expertIdsTensor(tokenIndex - startTokenId);
|
||||
int32_t curExpertCnt = sendTokenIdxTensor(tokenIndex - startTokenId);
|
||||
int32_t dstExpertOffset = sendOffsetTensor(dstExpertId);
|
||||
GM_ADDR rankGM =
|
||||
(__gm__ uint8_t *)(shareGM + hOutGMAlignSize * (dstExpertOffset + curExpertCnt));
|
||||
dstGT.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM);
|
||||
|
||||
if constexpr (DynamicQuant) {
|
||||
xInTensor = xInQueue.AllocTensor<XType>();
|
||||
DataCopyPad(xInTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams);
|
||||
xInQueue.EnQue(xInTensor);
|
||||
xInTensor = xInQueue.DeQue<XType>();
|
||||
xOutTensor = xOutQueue.AllocTensor<ExpandXOutType>();
|
||||
QuantProcess();
|
||||
xOutQueue.EnQue(xOutTensor);
|
||||
xOutTensor = xOutQueue.DeQue<ExpandXOutType>();
|
||||
FillTriple(xOutTensor, tokenIndex / topK, tokenIndex % topK);
|
||||
DataCopyPad(dstGT, xOutTensor, hCommuCopyOutParams);
|
||||
xOutQueue.FreeTensor(xOutTensor);
|
||||
} else {
|
||||
xTmpTensor = xQueue.AllocTensor<ExpandXOutType>();
|
||||
DataCopyPad(xTmpTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams);
|
||||
xQueue.EnQue(xTmpTensor);
|
||||
xTmpTensor = xQueue.DeQue<ExpandXOutType>();
|
||||
FillTriple(xTmpTensor, tokenIndex / topK, tokenIndex % topK);
|
||||
DataCopyPad(dstGT, xTmpTensor, hCommuCopyOutParams);
|
||||
xQueue.FreeTensor<ExpandXOutType>(xTmpTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::SetStatus()
|
||||
{
|
||||
uint32_t startExpId, endExpId, expNumPerCore;
|
||||
expNumPerCore = statusNumPerCore;
|
||||
startExpId = startStatusId;
|
||||
endExpId = endStatusId;
|
||||
if (startExpId > moeExpertNum) {
|
||||
SyncAll<true>();
|
||||
return;
|
||||
}
|
||||
uint32_t statusCntAlign = Ceil(expNumPerCore, 8) * 8;
|
||||
tpipe_->InitBuffer(statusBuf, statusCntAlign * UB_ALIGN); // moeNum / 48 * 32
|
||||
statusTensor = statusBuf.Get<int32_t>();
|
||||
Duplicate<int32_t>(statusTensor, 0, expNumPerCore * 8);
|
||||
uint64_t mask[2] = {0x101010101010101, 0};
|
||||
PipeBarrier<PIPE_V>();
|
||||
Duplicate<int32_t>(statusTensor, 0x3F800000, mask, statusCntAlign / 8, 1, 8);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
SyncAll<true>();
|
||||
for (uint32_t i = startExpId; i < endExpId; ++i) {
|
||||
uint32_t targetRankId = i / moeExpertNumPerRank;
|
||||
uint32_t offset = stateOffset * (epRankId + i % moeExpertNumPerRank * epRankSize);
|
||||
GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, targetRankId) + offset);
|
||||
dstStatusGT.SetGlobalBuffer((__gm__ int32_t *)rankGM);
|
||||
DataCopy<int32_t>(dstStatusGT, statusTensor[(i - startExpId) * 8], 8UL);
|
||||
}
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::WaitStatus()
|
||||
{
|
||||
tpipe_->Reset();
|
||||
uint32_t waitStatusBufSize = (((statusNumPerCore * UB_ALIGN) > 256) ? (statusNumPerCore * UB_ALIGN) : 256);
|
||||
tpipe_->InitBuffer(waitStatusBuf, waitStatusBufSize); // moeNum /48 * 32B = 43 * 32B
|
||||
tpipe_->InitBuffer(gatherMaskOutBuf, moeExpertNum * sizeof(float)); // moeNum * 4B
|
||||
tpipe_->InitBuffer(scalarBuf, UB_ALIGN * 3); // 96B
|
||||
tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 28K
|
||||
tpipe_->InitBuffer(recvOffsetBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B
|
||||
tpipe_->InitBuffer(recvCountBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B
|
||||
|
||||
recvOffsetTensor = recvOffsetBuf.Get<int32_t>();
|
||||
recvCountTensor = recvCountBuf.Get<int32_t>();
|
||||
DataCopyExtParams recvOffsetParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams recvCountParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(recvOffsetTensor, recvOffsetGT, recvOffsetParams, copyPadExtParams);
|
||||
DataCopyPad(recvCountTensor, recvCountGT, recvCountParams, copyPadExtParams);
|
||||
|
||||
if (startStatusId >= moeExpertNum) {
|
||||
SyncAll<true>();
|
||||
return;
|
||||
}
|
||||
|
||||
LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf.Get<float>();
|
||||
LocalTensor<float> statusSumOutTensor = scalarBuf.GetWithOffset<float>(UB_ALIGN / sizeof(float), UB_ALIGN);
|
||||
LocalTensor<float> statusFp32Tensor = waitStatusBuf.Get<float>();
|
||||
GlobalTensor<float> windowInstatusFp32Tensor;
|
||||
windowInstatusFp32Tensor.SetGlobalBuffer((__gm__ float *)(GetWindStateAddrByRankId(COMM_EP_IDX, epRankId)));
|
||||
uint32_t mask = 1;
|
||||
float compareTarget = static_cast<float>(1.0) * statusNumPerCore;
|
||||
float sumOfFlag = static_cast<float>(-1.0);
|
||||
DataCopyParams intriParams{static_cast<uint16_t>(statusNumPerCore), 1, 0, 0};
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
while (sumOfFlag != compareTarget) {
|
||||
DataCopy(statusFp32Tensor, windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)], intriParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_V>();
|
||||
ReduceSum(statusSumOutTensor, statusFp32Tensor, gatherMaskOutTensor, mask, statusNumPerCore, 1);
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
sumOfFlag = statusSumOutTensor.GetValue(0);
|
||||
}
|
||||
|
||||
// Clear state
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
DataCopyParams intriOutParams{static_cast<uint16_t>(statusNumPerCore), 1, 0, 0};
|
||||
uint64_t duplicateMask[2] = {0x101010101010101, 0};
|
||||
LocalTensor<int32_t> cleanStateTensor = waitStatusBuf.Get<int32_t>();
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Duplicate<int32_t>(cleanStateTensor, 0, duplicateMask, Ceil(statusNumPerCore, 8), 1, 8);
|
||||
SyncFunc<AscendC::HardEvent::V_MTE3>();
|
||||
DataCopy(windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)],
|
||||
cleanStateTensor.ReinterpretCast<float>(),
|
||||
intriOutParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
SyncAll<true>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::ShareToOutput()
|
||||
{
|
||||
if (startStatusId >= moeExpertNum) {
|
||||
return;
|
||||
}
|
||||
uint32_t fromRank, count, preCount, recvOffset, targetOffset;
|
||||
DataCopyPadExtParams<ExpandXOutType> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyExtParams dataCopyExandIdxParams{1U, sizeof(int32_t) * EXPAND_IDX_INFO, 0U, 0U, 0U};
|
||||
DataCopyExtParams dataCopyOutParams{1U, static_cast<uint32_t>(statusNumPerCore * sizeof(int32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams expandXCopyParams = {1U, static_cast<uint32_t>(h * sizeof(ExpandXOutType)), 0U, 0U, 0U};
|
||||
LocalTensor<int32_t> xTmpTensorInt;
|
||||
AscendC::TQueSync<PIPE_MTE2, PIPE_S> recvCountLocalSync;
|
||||
recvCountLocalSync.SetFlag(0);
|
||||
recvCountLocalSync.WaitFlag(0);
|
||||
for (uint32_t i = startStatusId; i < endStatusId; ++i) {
|
||||
preCount = 0;
|
||||
if (likely(i != 0)) {
|
||||
preCount = recvCountTensor(i - 1);
|
||||
}
|
||||
fromRank = i % epRankSize;
|
||||
count = recvCountTensor(i) - preCount;
|
||||
recvOffset = recvOffsetTensor(i);
|
||||
targetOffset = preCount;
|
||||
GM_ADDR recvStart =
|
||||
(__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, fromRank)) + recvOffset * hOutGMAlignSize;
|
||||
GlobalTensor<ExpandXOutType> srcTokenGT, dstTokenGT;
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
srcTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(recvStart + j * hOutGMAlignSize));
|
||||
xTmpTensor = xQueue.AllocTensor<ExpandXOutType>();
|
||||
DataCopyPad(xTmpTensor, srcTokenGT, hCommuCopyOutParams, copyPadExtParams);
|
||||
xQueue.EnQue(xTmpTensor);
|
||||
xTmpTensor = xQueue.DeQue<ExpandXOutType>();
|
||||
xTmpTensorInt = xTmpTensor.template ReinterpretCast<int32_t>();
|
||||
DataCopyPad(expandIdxOutGT[(targetOffset + j) * EXPAND_IDX_INFO],
|
||||
xTmpTensorInt[expandIdxStartIdx],
|
||||
dataCopyExandIdxParams);
|
||||
if constexpr (DynamicQuant) {
|
||||
DataCopyExtParams floatDataCopyParams = {1U, sizeof(float), 0U, 0U, 0U};
|
||||
LocalTensor<float> xOutFp32Tensor = xTmpTensor.template ReinterpretCast<float>();
|
||||
DataCopyPad(dynamicScalesOutGT[targetOffset + j],
|
||||
xOutFp32Tensor[hUBAlignSize / sizeof(float)],
|
||||
floatDataCopyParams);
|
||||
}
|
||||
dstTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM) + (targetOffset + j) * h, h);
|
||||
DataCopyPad(dstTokenGT, xTmpTensor, expandXCopyParams);
|
||||
xQueue.FreeTensor(xTmpTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::Process()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
InputToShare();
|
||||
SetStatus();
|
||||
WaitStatus();
|
||||
ShareToOutput();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MoeDispatchNormalImpl
|
||||
#endif
|
||||
@@ -0,0 +1,30 @@
|
||||
#ifndef MOE_DISPATCH_NORMAL_TILING_H
|
||||
#define MOE_DISPATCH_NORMAL_TILING_H
|
||||
|
||||
struct MoeDispatchNormalInfo {
|
||||
uint32_t epWorldSize; // epWorldSize
|
||||
uint32_t tpWorldSize; // tpWorldSize
|
||||
uint32_t epRankId; // epRankId
|
||||
uint32_t tpRankId; // tpRankId
|
||||
uint32_t moeExpertNum; // moe expert number
|
||||
uint32_t quantMode; // quant mode
|
||||
uint32_t globalBs; // globalBs = BS * worldSize
|
||||
uint32_t bs; // bs
|
||||
uint32_t k; // k
|
||||
uint32_t h; // h
|
||||
uint32_t aivNum; // aivNum
|
||||
bool isQuant; // whether quant or not
|
||||
bool reserved2; // reserved
|
||||
bool reserved3; // reserved
|
||||
uint64_t totalUbSize; // epWorldSize
|
||||
uint64_t totalWinSize;
|
||||
};
|
||||
|
||||
struct MoeDispatchNormalTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
Mc2CcTiling mc2CcTiling2;
|
||||
MoeDispatchNormalInfo moeDispatchNormalInfo;
|
||||
};
|
||||
|
||||
#endif
|
||||
49
csrc/notify_dispatch/op_host/CMakeLists.txt
Normal file
49
csrc/notify_dispatch/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME NotifyDispatch
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
notify_dispatch.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
notify_dispatch_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_notify_dispatch.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
84
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp
Normal file
84
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_notify_dispatch.h"
|
||||
|
||||
extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
extern aclnnStatus aclnnInnerNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerNotifyDispatchGetWorkspaceSize(
|
||||
sendData,
|
||||
tokenPerExpertData,
|
||||
sendCount,
|
||||
numTokens,
|
||||
commGroup,
|
||||
rankSize,
|
||||
rankId,
|
||||
localRankSize,
|
||||
localRankId,
|
||||
sendDataOffset,
|
||||
recvData,
|
||||
workspaceSize,
|
||||
executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerNotifyDispatch(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
61
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h
Normal file
61
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h
Normal file
@@ -0,0 +1,61 @@
|
||||
|
||||
#ifndef ACLNN_NOTIFY_DISPATCH_H_
|
||||
#define ACLNN_NOTIFY_DISPATCH_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* funtion: aclnnNotifyDispatchGetWorkspaceSize
|
||||
* parameters :
|
||||
* sendData : required
|
||||
* tokenPerExpertData : required
|
||||
* sendCount : required
|
||||
* numTokens : required
|
||||
* commGroup : required
|
||||
* rankSize : required
|
||||
* rankId : required
|
||||
* localRankSize : required
|
||||
* localRankId : required
|
||||
* sendDataOffset : required
|
||||
* recvData : required
|
||||
* workspaceSize : size of workspace(output).
|
||||
* executor : executor context(output).
|
||||
*/
|
||||
__attribute__((visibility("default")))
|
||||
aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
/* funtion: aclnnNotifyDispatch
|
||||
* parameters :
|
||||
* workspace : workspace memory addr(input).
|
||||
* workspaceSize : size of workspace(input).
|
||||
* executor : executor context(input).
|
||||
* stream : acl stream.
|
||||
*/
|
||||
__attribute__((visibility("default")))
|
||||
aclnnStatus aclnnNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
60
csrc/notify_dispatch/op_host/notify_dispatch.cpp
Normal file
60
csrc/notify_dispatch/op_host/notify_dispatch.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class NotifyDispatch : public OpDef {
|
||||
public:
|
||||
explicit NotifyDispatch(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("sendData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("tokenPerExpertData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("sendDataOffset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("recvData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("sendCount").Int();
|
||||
this->Attr("num_tokens").Int();
|
||||
this->Attr("comm_group").String();
|
||||
this->Attr("rank_size").Int();
|
||||
this->Attr("rank_id").Int();
|
||||
this->Attr("local_rank_size").Int();
|
||||
this->Attr("local_rank_id").Int();
|
||||
|
||||
OpAICoreConfig aicore_config_base;
|
||||
aicore_config_base.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
OpAICoreConfig aicore_config_A2 = aicore_config_base;
|
||||
aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false");
|
||||
|
||||
OpAICoreConfig aicore_config = aicore_config_base;
|
||||
aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->AICore().AddConfig("ascend910b", aicore_config_A2);
|
||||
this->MC2().HcclGroup("comm_group");
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(NotifyDispatch);
|
||||
} // namespace ops
|
||||
306
csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp
Normal file
306
csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp
Normal file
@@ -0,0 +1,306 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/notify_dispatch_tiling.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
#include "experiment/platform/platform/platform_infos_def.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace {
|
||||
class Mc2TilingUtils {
|
||||
public:
|
||||
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
|
||||
static uint64_t GetMaxWindowSize()
|
||||
{
|
||||
uint16_t defaultWindowSize = 200;
|
||||
if (getenv(HCCL_BUFFSIZE) == nullptr) {
|
||||
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
|
||||
} else {
|
||||
try {
|
||||
std::string envStr(getenv(HCCL_BUFFSIZE));
|
||||
defaultWindowSize = std::stoi(envStr);
|
||||
} catch (const std::invalid_argument &ia) {
|
||||
OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
|
||||
} catch (const std::out_of_range &oor) {
|
||||
OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
|
||||
}
|
||||
}
|
||||
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
|
||||
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
|
||||
return maxWindowSize;
|
||||
}
|
||||
};
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll
|
||||
|
||||
constexpr uint32_t INPUT_SEND_DATA_INDEX = 0;
|
||||
constexpr uint32_t INPUT_TOKEN_PER_EXPERT_INDEX = 1;
|
||||
|
||||
constexpr uint32_t OUTPUT_SEND_DATA_OFFSET_INDEX = 0;
|
||||
constexpr uint32_t OUTPUT_RECV_DATA_INDEX = 1;
|
||||
|
||||
constexpr uint32_t ATTR_SEND_COUNT_INDEX = 0;
|
||||
constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 1;
|
||||
constexpr uint32_t ATTR_COMM_GROUP_INDEX = 2;
|
||||
constexpr uint32_t ATTR_RANK_SIZE_INDEX = 3;
|
||||
constexpr uint32_t ATTR_RANK_ID_INDEX = 4;
|
||||
constexpr uint32_t ATTR_LOCAL_RANK_SIZE_INDEX = 5;
|
||||
constexpr uint32_t ATTR_LOCAL_RANK_ID_INDEX = 6;
|
||||
|
||||
const size_t MAX_GROUP_NAME_LENGTH = 128UL;
|
||||
const int64_t MAX_COMM_WORLD_SIZE = 384;
|
||||
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;
|
||||
constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes
|
||||
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
|
||||
|
||||
constexpr static int TILING_KEY_FLOAT16 = 20;
|
||||
constexpr static int TILING_KEY_BFLOAT16 = 21;
|
||||
constexpr static int TILING_KEY_FLOAT = 22;
|
||||
constexpr static int TILING_KEY_INT = 23;
|
||||
constexpr static int TILING_KEY_A2_TYPE = 100;
|
||||
|
||||
constexpr static int ALL_TO_ALL_CORE_NUM = 32;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static void PrintTilingDataInfo(const char *nodeName, NotifyDispatchTilingData &tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "rankSize is %u.", tilingData.notifyDispatchInfo.rankSize);
|
||||
OPS_LOG_D(nodeName, "rankId is %u.", tilingData.notifyDispatchInfo.rankId);
|
||||
OPS_LOG_D(nodeName, "localRankSize is %u.", tilingData.notifyDispatchInfo.localRankSize);
|
||||
OPS_LOG_D(nodeName, "localRankId is %u.", tilingData.notifyDispatchInfo.localRankId);
|
||||
OPS_LOG_D(nodeName, "sendCount is %u.", tilingData.notifyDispatchInfo.sendCount);
|
||||
OPS_LOG_D(nodeName, "numTokens is %u.", tilingData.notifyDispatchInfo.numTokens);
|
||||
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.notifyDispatchInfo.aivNum);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.notifyDispatchInfo.totalUbSize);
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
|
||||
NotifyDispatchTilingData &tilingData, std::string &commGroup)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto sendCountPtr = attrs->GetAttrPointer<int64_t>(ATTR_SEND_COUNT_INDEX);
|
||||
auto numTokenPtr = attrs->GetAttrPointer<int64_t>(ATTR_NUM_TOKENS_INDEX);
|
||||
auto commGroupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_COMM_GROUP_INDEX));
|
||||
auto rankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_SIZE_INDEX);
|
||||
auto rankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_ID_INDEX);
|
||||
auto localRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_LOCAL_RANK_SIZE_INDEX);
|
||||
auto localRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_LOCAL_RANK_ID_INDEX);
|
||||
|
||||
OPS_CHECK((commGroupPtr == nullptr) || (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "commGroupPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(sendCountPtr == nullptr, OPS_LOG_E(nodeName, "sendCountPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(numTokenPtr == nullptr, OPS_LOG_E(nodeName, "numTokenPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(rankSizePtr == nullptr, OPS_LOG_E(nodeName, "rankSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(rankIdPtr == nullptr, OPS_LOG_E(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(
|
||||
localRankSizePtr == nullptr, OPS_LOG_E(nodeName, "localRankSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(localRankIdPtr == nullptr, OPS_LOG_E(nodeName, "localRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_COMM_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
|
||||
MAX_COMM_WORLD_SIZE,
|
||||
*rankSizePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*rankIdPtr < 0) || (*rankIdPtr >= *rankSizePtr),
|
||||
OPS_LOG_E(nodeName, "rankId is invalid, only support [0, %ld), but got rankId=%ld.", *rankSizePtr, *rankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*sendCountPtr <= 0),
|
||||
OPS_LOG_E(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%ld.", *sendCountPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*numTokenPtr <= 0),
|
||||
OPS_LOG_E(nodeName, "numTokenPtr is invalid, only support > 0, but got numTokenPtr=%ld.", *numTokenPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
commGroup = std::string(commGroupPtr);
|
||||
tilingData.notifyDispatchInfo.rankSize = static_cast<uint32_t>(*rankSizePtr);
|
||||
tilingData.notifyDispatchInfo.rankId = static_cast<uint32_t>(*rankIdPtr);
|
||||
tilingData.notifyDispatchInfo.localRankSize = static_cast<uint32_t>(*localRankSizePtr);
|
||||
tilingData.notifyDispatchInfo.localRankId = static_cast<uint32_t>(*localRankIdPtr);
|
||||
tilingData.notifyDispatchInfo.sendCount = static_cast<uint32_t>(*sendCountPtr);
|
||||
tilingData.notifyDispatchInfo.numTokens = static_cast<uint32_t>(*numTokenPtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void SetHcommCfg(const gert::TilingContext *context,
|
||||
NotifyDispatchTilingData *tiling, const std::string commGroup)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "NotifyDispatch commGroup = %s", commGroup.c_str());
|
||||
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(commGroup, opType1, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static bool CheckTensorDataType(
|
||||
gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto sendData = context->GetInputDesc(INPUT_SEND_DATA_INDEX);
|
||||
OPS_CHECK(sendData == nullptr, OPS_LOG_E(nodeName, "sendData is null."), return false);
|
||||
OPS_CHECK((sendData->GetDataType() != ge::DT_BF16) && (sendData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(sendData->GetDataType() != ge::DT_FLOAT) && (sendData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"sendData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(sendData->GetDataType())),
|
||||
return false);
|
||||
uint64_t dataSize;
|
||||
if ((sendData->GetDataType() == ge::DT_BF16) || (sendData->GetDataType() == ge::DT_FLOAT16)) {
|
||||
dataSize = 2;
|
||||
} else {
|
||||
dataSize = 4;
|
||||
}
|
||||
auto tokenPerExpertData = context->GetInputDesc(INPUT_TOKEN_PER_EXPERT_INDEX);
|
||||
OPS_CHECK(tokenPerExpertData == nullptr, OPS_LOG_E(nodeName, "tokenPerExpertData is null."), return false);
|
||||
OPS_CHECK((tokenPerExpertData->GetDataType() != ge::DT_BF16) && (tokenPerExpertData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(tokenPerExpertData->GetDataType() != ge::DT_FLOAT) && (tokenPerExpertData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tokenPerExpertData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(tokenPerExpertData->GetDataType())),
|
||||
return false);
|
||||
|
||||
auto sendDataOffset = context->GetInputDesc(OUTPUT_SEND_DATA_OFFSET_INDEX);
|
||||
OPS_CHECK(sendDataOffset == nullptr, OPS_LOG_E(nodeName, "sendDataOffset is null."), return false);
|
||||
OPS_CHECK((sendDataOffset->GetDataType() != ge::DT_BF16) && (sendDataOffset->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(sendDataOffset->GetDataType() != ge::DT_FLOAT) && (sendDataOffset->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"sendDataOffset datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(sendDataOffset->GetDataType())),
|
||||
return false);
|
||||
|
||||
auto recvData = context->GetInputDesc(OUTPUT_RECV_DATA_INDEX);
|
||||
OPS_CHECK(recvData == nullptr, OPS_LOG_E(nodeName, "recvData is null."), return false);
|
||||
OPS_CHECK((recvData->GetDataType() != ge::DT_BF16) && (recvData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(recvData->GetDataType() != ge::DT_FLOAT) && (recvData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"recvData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(recvData->GetDataType())),
|
||||
return false);
|
||||
|
||||
// Verify the size of the win area
|
||||
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
|
||||
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
|
||||
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount;
|
||||
if (actualSize > maxWindowSize) {
|
||||
OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckTensor(
|
||||
gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus NotifyDispatchTilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
std::string commGroup = "";
|
||||
OPS_LOG_I(nodeName, "Enter NotifyDispatch tiling check func.");
|
||||
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, commGroup) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check param failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, commGroup);
|
||||
|
||||
int tilingKey = TILING_KEY_INT;
|
||||
auto sendDtype = context->GetInputDesc(0)->GetDataType();
|
||||
if (sendDtype == ge::DT_FLOAT16) {
|
||||
tilingKey = TILING_KEY_FLOAT16;
|
||||
} else if (sendDtype == ge::DT_BF16) {
|
||||
tilingKey = TILING_KEY_BFLOAT16;
|
||||
} else if (sendDtype == ge::DT_FLOAT) {
|
||||
tilingKey = TILING_KEY_FLOAT;
|
||||
}
|
||||
|
||||
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
|
||||
fe::PlatFormInfos &platformInfo = *platformInfoPtr;
|
||||
|
||||
std::string socVersion;
|
||||
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);
|
||||
|
||||
if (socVersion == "Ascend910B") {
|
||||
tilingKey = tilingKey + TILING_KEY_A2_TYPE;
|
||||
}
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t blockDim;
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
|
||||
blockDim = aivNum;
|
||||
context->SetBlockDim(blockDim);
|
||||
tilingData->notifyDispatchInfo.totalUbSize = ubSize;
|
||||
tilingData->notifyDispatchInfo.aivNum = aivNum;
|
||||
OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus NotifyDispatchTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
ge::graphStatus ret = NotifyDispatchTilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct NotifyDispatchCompileInfo {};
|
||||
ge::graphStatus TilingParseForNotifyDispatch(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(NotifyDispatch)
|
||||
.Tiling(NotifyDispatchTilingFunc)
|
||||
.TilingParse<NotifyDispatchCompileInfo>(TilingParseForNotifyDispatch);
|
||||
} // namespace optiling
|
||||
57
csrc/notify_dispatch/op_kernel/notify_dispatch.cpp
Normal file
57
csrc/notify_dispatch/op_kernel/notify_dispatch.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "notify_dispatch.h"
|
||||
#include "notify_dispatch_tiling.h"
|
||||
|
||||
#define TILING_KEY_FLOAT16 20
|
||||
#define TILING_KEY_BFLOAT16 21
|
||||
#define TILING_KEY_FLOAT 22
|
||||
#define TILING_KEY_INT 23
|
||||
|
||||
#define KERNEL_USE_WORKSPACE (1 * 1024 * 1024)
|
||||
|
||||
extern "C" __global__ __aicore__ void notify_dispatch(
|
||||
GM_ADDR sendData, GM_ADDR tokenPerExpertData, GM_ADDR sendDataOffset, GM_ADDR recvData, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(NotifyDispatchTilingData);
|
||||
GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tiling);
|
||||
|
||||
// hcomm will set magic later in init
|
||||
uint32_t magic = 1;
|
||||
GM_ADDR commArgs = nullptr;
|
||||
|
||||
int localRank = tilingData.notifyDispatchInfo.localRankId;
|
||||
int localRankSize = tilingData.notifyDispatchInfo.localRankSize;
|
||||
int rank = tilingData.notifyDispatchInfo.rankId;
|
||||
int rankSize = tilingData.notifyDispatchInfo.rankSize;
|
||||
int64_t len = tilingData.notifyDispatchInfo.sendCount;
|
||||
int64_t numTokens = tilingData.notifyDispatchInfo.numTokens;
|
||||
|
||||
GM_ADDR sendDataInput = sendData;
|
||||
GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
|
||||
GM_ADDR sendDataOffsetOutput = sendDataOffset;
|
||||
GM_ADDR recvDataOutput = recvData;
|
||||
|
||||
// fill in unused args
|
||||
uint32_t extraFlag = 0;
|
||||
GM_ADDR scale = nullptr;
|
||||
int root = 0;
|
||||
int op = 0;
|
||||
int cycleCount = 0;
|
||||
int64_t scaleCount = 0;
|
||||
GM_ADDR offset = nullptr;
|
||||
int blockNum = GetBlockNum();
|
||||
|
||||
if (TILING_KEY_IS(TILING_KEY_FLOAT16)) {
|
||||
NotifyDispatch<float16_t> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
} else if (TILING_KEY_IS(TILING_KEY_FLOAT)) {
|
||||
NotifyDispatch<float> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
} else if (TILING_KEY_IS(TILING_KEY_INT)) {
|
||||
NotifyDispatch<int> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
}
|
||||
}
|
||||
495
csrc/notify_dispatch/op_kernel/notify_dispatch.h
Normal file
495
csrc/notify_dispatch/op_kernel/notify_dispatch.h
Normal file
@@ -0,0 +1,495 @@
|
||||
#ifndef NOTIFY_DISPATCH_H
|
||||
#define NOTIFY_DISPATCH_H
|
||||
|
||||
#include <climits>
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#include "../common/comm_args.h"
|
||||
#include "../common/data_copy.h"
|
||||
#include "../common/sync_collectives.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace Moe;
|
||||
|
||||
#define KERNELS_ARGS_FUN_ALL2ALL() \
|
||||
GM_ADDR sendDataInput, GM_ADDR tokenPerExpertDataInput, GM_ADDR sendDataOffsetOutput, GM_ADDR recvDataOutput, \
|
||||
int64_t len, int64_t numTokens, int op, int root, int cycleCount, GM_ADDR scale, int64_t scaleCount, \
|
||||
GM_ADDR offset, int localRank, int localRankSize, GM_ADDR commArgs, int magic
|
||||
|
||||
#define KERNELS_ARGS_CALL_ALL2ALL() \
|
||||
sendDataInput, tokenPerExpertDataInput, sendDataOffsetOutput, recvDataOutput, len, numTokens, op, root, \
|
||||
cycleCount, scale, scaleCount, offset, localRank, localRankSize, commArgs, magic
|
||||
|
||||
template <typename T>
|
||||
class NotifyDispatch {
|
||||
constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; // Invalid rank
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_X = 24; // Maximum number of cores provided by the producer stage
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_Y = 16; // Maximum number of cores provided by the consumer stage
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_Z = 16; // Maximum number of cores provided by the consumer stage 2
|
||||
constexpr static int64_t SHARE_QUE_DEPTH = 1; // Depth of a single shared queue
|
||||
constexpr static int64_t RANK_NUM_PER_NODE = 16;
|
||||
constexpr static int64_t SIO_NUM = 2; // Depth of a single shared queue
|
||||
constexpr static int64_t MAX_CORE_NUM = 48;
|
||||
constexpr static int64_t MAX_RANK_PER_CORE = 8;
|
||||
constexpr static int64_t MULTI_RANK_SIZE = 48;
|
||||
constexpr static int64_t MAX_BUFFER_NUMBER = 10;
|
||||
|
||||
constexpr static int64_t IDLER_CORE = 0; // Idle core
|
||||
constexpr static int64_t PRODUCER_CORE = 1; // Producer group, responsible for writing data to shared memory, input->share, or share->share
|
||||
constexpr static int64_t CONSUMER_CORE = 2; // Consumer group, responsible for reading data from shared memory, share->output
|
||||
constexpr static int64_t CONSUMER_CORE2 = 3;
|
||||
|
||||
public:
|
||||
__aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag)
|
||||
: rank(rank), rankSize(rankSize), extraFlag(extraFlag)
|
||||
{}
|
||||
|
||||
__aicore__ inline void Init(KERNELS_ARGS_FUN_ALL2ALL())
|
||||
{
|
||||
InitSmallFullMesh(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
nodeNum = rankSize / localRankSize;
|
||||
localRankId = rank % localRankSize;
|
||||
localNodeId = rank / localRankSize;
|
||||
perNodeDataNum = GetDataCount(len, nodeNum); // 128K/4 = 32K
|
||||
perRankDataNum = GetDataCount(len, rankSize); // 128K/64 = 2K
|
||||
|
||||
tokenPerExpertDataAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
sendDataOffsetAlignLen = Ceil(numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
sendDataAlignLen = Ceil(numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
|
||||
// Initialize core grouping
|
||||
InitCoreGroup();
|
||||
// Initialize data slicing
|
||||
InitDataSlice();
|
||||
|
||||
this->sendDataInput = (__gm__ T *)sendDataInput;
|
||||
this->tokenPerExpertDataInput = (__gm__ int32_t *)tokenPerExpertDataInput;
|
||||
this->sendDataOffsetOutput = (__gm__ T *)sendDataOffsetOutput;
|
||||
this->recvDataOutput = (__gm__ T *)recvDataOutput;
|
||||
sendDataInputGt.SetGlobalBuffer((__gm__ T *)sendDataInput);
|
||||
tokenPerExpertDataInputGt.SetGlobalBuffer((__gm__ int32_t *)tokenPerExpertDataInput);
|
||||
sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput);
|
||||
recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput);
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
if (blockIdx < 1) {
|
||||
AssembleSendData();
|
||||
}
|
||||
SyncAll<true>();
|
||||
if (blockIdx < coreNumPerStageX) {
|
||||
InputToShareSlice();
|
||||
}
|
||||
if (blockIdx < coreNumPerStageY) {
|
||||
ShareToShareSlice();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void InitCoreGroup()
|
||||
{
|
||||
coreNumPerStageY = MAX_CORE_NUM;
|
||||
coreNumPerStageX = MAX_CORE_NUM;
|
||||
rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM;
|
||||
}
|
||||
|
||||
__aicore__ inline void InitDataSlice()
|
||||
{
|
||||
// The producer is responsible for moving the input data of this rank to shared memory, input-->share
|
||||
if (blockIdx < coreNumPerStageX) {
|
||||
ProducerDataSlice();
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProducerDataSlice()
|
||||
{
|
||||
// The ipcQue responsible for the current core
|
||||
writeGt.SetGlobalBuffer((__gm__ T *)(shareAddrs[rank] + IPC_DATA_OFFSET));
|
||||
}
|
||||
|
||||
__aicore__ inline void AssembleSendData()
|
||||
{
|
||||
pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen);
|
||||
pipe.InitBuffer(sendDataBuf, sendDataAlignLen);
|
||||
pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen);
|
||||
|
||||
__ubuf__ int32_t *tokenPerExpertUB = (__ubuf__ int32_t *)get_imm(96);
|
||||
CpGM2UB(tokenPerExpertUB, (__gm__ int32_t *)tokenPerExpertDataInputGt.GetPhyAddr(), tokenPerExpertDataAlignLen);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
|
||||
__ubuf__ T *sendDataOffsetUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen);
|
||||
__ubuf__ T *sendDataUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen + sendDataOffsetAlignLen);
|
||||
|
||||
int prefixSum = 0;
|
||||
for (int i = 0; i < numExperts; ++i) {
|
||||
int numTokensExpert = tokenPerExpertUB[i];
|
||||
sendDataUB[i * sendPerGroup] = numTokensExpert;
|
||||
sendDataUB[i * sendPerGroup + 1] = prefixSum;
|
||||
sendDataUB[i * sendPerGroup + 2] = numTokens;
|
||||
sendDataOffsetUB[i] = prefixSum;
|
||||
|
||||
prefixSum += numTokensExpert;
|
||||
}
|
||||
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
|
||||
CpUB2GM((__gm__ T *)sendDataInputGt.GetPhyAddr(), sendDataUB, sendDataAlignLen);
|
||||
CpUB2GM((__gm__ T *)sendDataOffsetOutputGt.GetPhyAddr(), sendDataOffsetUB, sendDataOffsetAlignLen);
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
// copy input to other rank share
|
||||
__aicore__ inline void InputToShareSlice()
|
||||
{
|
||||
__ubuf__ int64_t *inputUB = (__ubuf__ int64_t *)get_imm(0);
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
readGt = sendDataInputGt[copyOffset * perRankDataNum];
|
||||
CpGM2GMPingPong<T>(
|
||||
copyLen * perRankDataNum * sizeof(T), readGt, writeGt[copyOffset * perRankDataNum], COPYONLY);
|
||||
int64_t v = MergeMagicWithValue(magic, 1);
|
||||
*inputUB = v;
|
||||
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
CpUB2GM((__gm__ int64_t *)(shareAddrs[i]) + rank * FLAG_UNIT_INT_NUM, inputUB, sizeof(int64_t));
|
||||
}
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value)
|
||||
{
|
||||
// magic as the high part, eventID as the low part, combined into a value for comparison
|
||||
return (static_cast<int64_t>(static_cast<uint32_t>(magic)) << MAGIC_OFFSET) | static_cast<int64_t>(value);
|
||||
}
|
||||
|
||||
__aicore__ inline void ShareToShareSlice()
|
||||
{
|
||||
__ubuf__ T *inputUB = (__ubuf__ T *)get_imm(96);
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
int checkRank[MAX_RANK_PER_CORE];
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
checkRank[i - copyOffset] = i + rank % copyLen;
|
||||
if (checkRank[i - copyOffset] >= copyOffset + copyLen) {
|
||||
checkRank[i - copyOffset] -= copyLen;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < copyLen; i++) {
|
||||
readGt1[i].SetGlobalBuffer((__gm__ T *)(shareAddrs[checkRank[i]] + IPC_DATA_OFFSET));
|
||||
}
|
||||
sync.WaitSyncFlag(magic, 1, copyOffset, rank, copyLen);
|
||||
for (int i = 0; i < copyLen; i++) {
|
||||
CpGM2GMPingPong<T>(perRankDataNum * sizeof(T),
|
||||
readGt1[i][rank * perRankDataNum],
|
||||
recvDataOutputGt[checkRank[i] * perRankDataNum],
|
||||
COPYONLY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum);
|
||||
__aicore__ inline GM_ADDR GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx);
|
||||
__aicore__ inline int32_t GetMagicValue(void);
|
||||
FORCE_INLINE_AICORE void InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL());
|
||||
template <typename F>
|
||||
FORCE_INLINE_AICORE void SetAtomic(int op);
|
||||
FORCE_INLINE_AICORE void UnsetAtomic(int op);
|
||||
template<HardEvent eventType>
|
||||
FORCE_INLINE_AICORE void SetWaitEvent(event_t eventId);
|
||||
template <typename K, typename U = K>
|
||||
FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor<U>& sendDataInputGt,
|
||||
const GlobalTensor<K>& recvDataOutputGT, int op);
|
||||
|
||||
GlobalTensor<T> sendDataInputGt;
|
||||
GlobalTensor<int> tokenPerExpertDataInputGt;
|
||||
GlobalTensor<T> sendDataOffsetOutputGt;
|
||||
GlobalTensor<T> recvDataOutputGt;
|
||||
GlobalTensor<T> readGt;
|
||||
GlobalTensor<T> writeGt;
|
||||
GlobalTensor<T> readGt1[MAX_BUFFER_NUMBER];
|
||||
GlobalTensor<T> ipcGT;
|
||||
GlobalTensor<int64_t> sendCountMatrixGm;
|
||||
__gm__ T *sendDataInput;
|
||||
__gm__ int *tokenPerExpertDataInput;
|
||||
__gm__ T *sendDataOffsetOutput;
|
||||
__gm__ T *recvDataOutput;
|
||||
int64_t isPad = 0;
|
||||
int64_t maxSliceNum;
|
||||
int64_t revLen = 0;
|
||||
int64_t sendLen = 0;
|
||||
int64_t sliceLen;
|
||||
int64_t perNodeDataNum;
|
||||
int64_t perRankDataNum;
|
||||
int64_t curRankDataNum;
|
||||
int64_t sendOffset[MULTI_RANK_SIZE];
|
||||
int64_t revOffset[MULTI_RANK_SIZE];
|
||||
int64_t inputDataLen[MULTI_RANK_SIZE];
|
||||
|
||||
int64_t nodeNum;
|
||||
int64_t localRankId;
|
||||
int64_t localNodeId;
|
||||
int64_t targetNode;
|
||||
int64_t targetLocalRankIds[2];
|
||||
int64_t queLen;
|
||||
int64_t queSize;
|
||||
int64_t coreNumPerStageX; // Number of cores used per stage
|
||||
int64_t coreNumPerStageY; // Number of cores used per stage
|
||||
int64_t coreNumPerStageZ; // Number of cores used per stage
|
||||
int64_t flagNumPerStage; // Number of synchronization flags used per stage
|
||||
int64_t coreNumPerNode; // Number of cores allocated per node
|
||||
int64_t coreNumPerRank; // Number of cores allocated per rank
|
||||
int64_t rankNumPerCore; // Number of ranks responsible per core
|
||||
int64_t coreGroup; // Functional group of the current core
|
||||
int64_t targetRank[MULTI_RANK_SIZE]; // Ranks responsible by the current core
|
||||
int64_t targetRankX;
|
||||
int64_t targetRankY;
|
||||
|
||||
int64_t queElemLen; // Size of each element in the shared memory queue (in terms of T)
|
||||
|
||||
int64_t copyLen; // Length of the current data slice being copied (in terms of T)
|
||||
|
||||
// for coll
|
||||
int rank;
|
||||
int rankSize;
|
||||
int localRank = 0;
|
||||
int localRankSize = 0;
|
||||
int xRankSize = 0;
|
||||
int yRankSize = 0;
|
||||
int xRankIdx = 0;
|
||||
int yRankIdx = 0;
|
||||
uint32_t extraFlag;
|
||||
int numTokens;
|
||||
int sendPerGroup = 3;
|
||||
int root;
|
||||
int64_t len;
|
||||
int64_t numExperts;
|
||||
int64_t magic;
|
||||
int64_t blockIdx; // Index of the current aicore
|
||||
int64_t blockNum; // Total number of aicores for the current rank
|
||||
int32_t numRanks;
|
||||
int64_t timeout;
|
||||
uint16_t *rootRanks;
|
||||
GM_ADDR scale;
|
||||
GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses
|
||||
__gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GlobalTensor<GM_ADDR> peerMemsAddrGm_;
|
||||
GlobalTensor<int64_t> dfx;
|
||||
TPipe pipe;
|
||||
TBuf<QuePosition::VECCALC> tBuf;
|
||||
TBuf<> tokenPerExpertDataBuf;
|
||||
TBuf<> sendDataOffsetBuf;
|
||||
TBuf<> sendDataBuf;
|
||||
|
||||
uint32_t sendDataAlignLen{0};
|
||||
uint32_t tokenPerExpertDataAlignLen{0};
|
||||
uint32_t sendDataOffsetAlignLen{0};
|
||||
|
||||
SyncCollectives sync;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE int64_t NotifyDispatch<T>::GetDataCount(const int64_t dataLen, const int64_t useBlockNum)
|
||||
{
|
||||
return dataLen / useBlockNum;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline GM_ADDR NotifyDispatch<T>::GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx)
|
||||
{
|
||||
uint32_t curRankId = rank;
|
||||
#ifdef OPT_RANK_OFFSET
|
||||
#pragma message("use rank offset")
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + rankId * OPT_RANK_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) +
|
||||
rankId * OPT_RANK_OFFSET;
|
||||
#else
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn);
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Assign values to winContext_[COMM_EP_IDX] and blockIdx before calling
|
||||
template <typename T>
|
||||
__aicore__ inline int32_t NotifyDispatch<T>::GetMagicValue(void)
|
||||
{
|
||||
int32_t magic = 0;
|
||||
GlobalTensor<int32_t> selfDataStatusTensor;
|
||||
GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp);
|
||||
selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
|
||||
selfDataStatusTensor[blockIdx * UB_ALIGN_SIZE]);
|
||||
magic = selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE);
|
||||
if (magic <= 0) {
|
||||
magic = 1;
|
||||
}
|
||||
selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE) = magic + 1;
|
||||
return magic;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL())
|
||||
{
|
||||
this->root = root;
|
||||
this->len = len;
|
||||
this->numExperts = len / sendPerGroup;
|
||||
this->numTokens = numTokens;
|
||||
this->scale = scale;
|
||||
this->localRank = localRank;
|
||||
this->localRankSize = localRankSize;
|
||||
this->xRankSize = localRankSize;
|
||||
this->yRankSize = rankSize / localRankSize;
|
||||
this->xRankIdx = rank % localRankSize;
|
||||
this->yRankIdx = rank / localRankSize;
|
||||
blockIdx = GetBlockIdx();
|
||||
blockNum = GetBlockNum();
|
||||
uint8_t ctxIdx;
|
||||
|
||||
winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
this->magic = GetMagicValue();
|
||||
ctxIdx = COMM_EP_IDX;
|
||||
|
||||
shareAddrs[rank] = GetWindAddrByRankId(rank, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
|
||||
int64_t rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM;
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
int64_t copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
shareAddrs[i] = GetWindAddrByRankId(i, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
}
|
||||
}
|
||||
|
||||
// When the number of cores is more than the number of ranks, each core is responsible for fetching data from a specified rank
|
||||
int coreNumPerRank = blockNum / rankSize; // Calculate the number of cores assigned to read for each rank, e.g., 48 cores 4 ranks, each rank is assigned 12 cores
|
||||
int maxCore = coreNumPerRank * rankSize; // Calculate the maximum number of cores that can be used for reading, cores exceeding this number will not take action
|
||||
if (blockIdx < maxCore) {
|
||||
int readRank = blockIdx / coreNumPerRank; // Calculate the rank to be read based on the block, 48 cores divided into 4 groups
|
||||
shareAddrs[readRank] = GetWindAddrByRankId(readRank, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
}
|
||||
|
||||
pipe.InitBuffer(tBuf, UB_SINGLE_TOTAL_SIZE_MAX);
|
||||
|
||||
sync.Init(rank, rankSize, shareAddrs, tBuf);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy data from GM to GM with ping-pong method.
|
||||
* @tparam dataSizeRemain The remaining size of data to be copied.
|
||||
* @tparam K The type of output data.
|
||||
* @tparam U The type of input data.
|
||||
* @param sendDataInputGt The global tensor of send data.
|
||||
* @param recvDataOutputGT The global tensor of recv data.
|
||||
* @param op The operation to be performed during the copy.
|
||||
* @details This function copies data from global memory to global memory using a ping-pong method.
|
||||
* It first checks if the input and output types are the same. If they are, it uses a single buffer.
|
||||
* If they are not, it divides the buffer according to the size ratio of the types and aligns it to 32 bytes.
|
||||
* Then, it sets the atomic operation, waits for the flags, and performs the copy operation.
|
||||
*/
|
||||
template <typename T>
|
||||
template <typename K, typename U>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor<U>& sendDataInputGt,
|
||||
const GlobalTensor<K>& recvDataOutputGT, int op)
|
||||
{
|
||||
// General case (U = K), input/output are the same, share one UB
|
||||
// Only when conversion is needed (U->K), UB will be divided into two parts according to the ratio of sizeof(U):sizeof(K) and aligned to 32 bytes
|
||||
constexpr int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX;
|
||||
constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(K) + sizeof(U)) / UB_ALIGN_SIZE * UB_ALIGN_SIZE;
|
||||
constexpr int32_t inputUbBlockSize = std::is_same_v<K, U> ? ubBlockSize : ubAlignNum * sizeof(U);
|
||||
constexpr int32_t outputUbBlockSize = std::is_same_v<K, U> ? ubBlockSize : ubAlignNum * sizeof(K);
|
||||
|
||||
__gm__ U *input = const_cast<__gm__ U *>(sendDataInputGt.GetPhyAddr());
|
||||
__gm__ K *output = const_cast<__gm__ K *>(recvDataOutputGT.GetPhyAddr());
|
||||
__ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET), (__ubuf__ U*)(UB_MID_OFFSET)};
|
||||
__ubuf__ K* outputUB[2] = {(__ubuf__ K*)inputUB[0], (__ubuf__ K*)inputUB[1]};
|
||||
if constexpr (!std::is_same_v<K, U>) {
|
||||
outputUB[0] = (__ubuf__ K*)(inputUB[0] + inputUbBlockSize / sizeof(U));
|
||||
outputUB[1] = (__ubuf__ K*)(inputUB[1] + inputUbBlockSize / sizeof(U));
|
||||
}
|
||||
int inputOffsetNum = 0;
|
||||
int outputOffsetNum = 0;
|
||||
if (dataSizeRemain <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
SetAtomic<K>(op);
|
||||
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0); // MTE2 waits for MTE3
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); // MTE2 waits for MTE3
|
||||
for (int64_t i = 0; dataSizeRemain > 0; i++) {
|
||||
// size and dataSizeRemain both refer to the output size
|
||||
uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain;
|
||||
event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1;
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(eventId);
|
||||
CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(K) * sizeof(U));
|
||||
if constexpr (!std::is_same_v<K, U>) {
|
||||
SetWaitEvent<HardEvent::MTE2_V>(eventId);
|
||||
CastImpl((i & 1) ? outputUB[0] : outputUB[1], (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE,
|
||||
size / sizeof(K));
|
||||
SetWaitEvent<HardEvent::V_MTE3>(eventId);
|
||||
}
|
||||
AscendC::SetFlag<HardEvent::MTE2_MTE3>(eventId);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_MTE3>(eventId);
|
||||
CpUB2GM(output + outputOffsetNum, (i & 1) ? outputUB[0] : outputUB[1], size);
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(eventId);
|
||||
|
||||
dataSizeRemain -= size;
|
||||
inputOffsetNum += (size / sizeof(K));
|
||||
outputOffsetNum += (size / sizeof(K));
|
||||
}
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0); // MTE2 waits for MTE3
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); // MTE2 waits for MTE3
|
||||
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID3); // Scalar waits for MTE3
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID3);
|
||||
|
||||
UnsetAtomic(op);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename F>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::SetAtomic(int op)
|
||||
{
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
if (op != -1) {
|
||||
#ifdef __DAV_C220_VEC__
|
||||
SetAtomicOpType<F>(op);
|
||||
#endif
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::UnsetAtomic(int op)
|
||||
{
|
||||
if (op != -1) {
|
||||
AscendC::SetAtomicNone();
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<HardEvent eventType>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::SetWaitEvent(event_t eventId)
|
||||
{
|
||||
AscendC::SetFlag<eventType>(eventId);
|
||||
AscendC::WaitFlag<eventType>(eventId);
|
||||
}
|
||||
|
||||
#endif // NOTIFY_DISPATCH_H
|
||||
23
csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h
Normal file
23
csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef NOTIFY_DISPATCH_TILING_H
|
||||
#define NOTIFY_DISPATCH_TILING_H
|
||||
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
struct NotifyDispatchInfo {
|
||||
uint32_t rankSize;
|
||||
uint32_t rankId;
|
||||
uint32_t localRankSize;
|
||||
uint32_t localRankId;
|
||||
uint32_t sendCount;
|
||||
uint32_t numTokens;
|
||||
uint32_t aivNum;
|
||||
uint64_t totalUbSize;
|
||||
};
|
||||
|
||||
struct NotifyDispatchTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
NotifyDispatchInfo notifyDispatchInfo;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <torch/torch.h>
|
||||
#include <torch_npu/csrc/core/npu/NPUStream.h>
|
||||
#include <torch_npu/csrc/framework/OpCommand.h>
|
||||
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
|
||||
#include "torch_npu/csrc/core/npu/NPUGuard.h"
|
||||
#include <torch_npu/csrc/npu/Module.h>
|
||||
#include "acl/acl.h"
|
||||
@@ -808,6 +809,246 @@ at::Tensor npu_sparse_flash_attention(
|
||||
return output;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts,
|
||||
int64_t num_ranks) {
|
||||
TORCH_BIND_ASSERT(topk_idx.dim() == 2);
|
||||
TORCH_BIND_ASSERT(topk_idx.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_experts > 0);
|
||||
|
||||
const int num_tokens = topk_idx.size(0);
|
||||
const int num_topk = topk_idx.size(1);
|
||||
|
||||
auto device = topk_idx.device();
|
||||
auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
|
||||
auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
|
||||
auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
|
||||
|
||||
EXEC_NPU_CMD(aclnnDispatchLayout,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_ranks,
|
||||
num_experts,
|
||||
num_topk,
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank);
|
||||
|
||||
auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
|
||||
|
||||
return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
|
||||
const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
|
||||
const at::Tensor& num_tokens_per_rank, const at::Tensor& is_token_in_rank, at::Tensor& num_tokens_per_expert,
|
||||
int64_t num_worst_tokens, c10::string_view groupEp, int64_t rank, int64_t num_ranks) {
|
||||
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
char* group_ep_ptr = &group_ep_chrs[0];
|
||||
at::Tensor new_x = x;
|
||||
|
||||
// Type checks
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.scalar_type() == at::kBool);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.scalar_type() == at::kInt);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.scalar_type() == at::kInt);
|
||||
|
||||
// Shape and contiguous checks
|
||||
TORCH_BIND_ASSERT(new_x.dim() == 2 and new_x.is_contiguous());
|
||||
// TORCH_BIND_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.size(0) == new_x.size(0) and is_token_in_rank.size(1) == num_ranks);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.dim() == 1 and num_tokens_per_expert.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.size(0) % num_ranks == 0);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.dim() == 1 and num_tokens_per_rank.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.size(0) == num_ranks);
|
||||
|
||||
auto num_tokens = static_cast<int>(new_x.size(0));
|
||||
auto hidden = static_cast<int>(new_x.size(1));
|
||||
auto num_experts = static_cast<int64_t>(num_tokens_per_expert.size(0));
|
||||
auto num_local_experts = static_cast<int>(num_experts / num_ranks);
|
||||
|
||||
// Top-k checks
|
||||
int num_topk = 0;
|
||||
num_topk = static_cast<int>(topk_idx.size(1));
|
||||
TORCH_BIND_ASSERT(num_experts > 0);
|
||||
TORCH_BIND_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
|
||||
TORCH_BIND_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens == topk_idx.size(0));
|
||||
TORCH_BIND_ASSERT(num_topk == topk_weights.size(1));
|
||||
TORCH_BIND_ASSERT(topk_weights.scalar_type() == at::kFloat);
|
||||
|
||||
int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
|
||||
|
||||
auto send_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
|
||||
int64_t send_count = send_per_group * num_local_experts * num_ranks;
|
||||
|
||||
auto send_data_offset = at::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
|
||||
at::Tensor recv_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
|
||||
|
||||
int64_t local_rank_size = num_ranks;
|
||||
int64_t local_rank_id = rank % local_rank_size;
|
||||
|
||||
EXEC_NPU_CMD(aclnnNotifyDispatch,
|
||||
send_data,
|
||||
num_tokens_per_expert,
|
||||
send_count,
|
||||
num_tokens,
|
||||
group_ep_ptr, // commGroup
|
||||
num_ranks, // rankSize
|
||||
rank, // rankId
|
||||
local_rank_size,
|
||||
local_rank_id,
|
||||
send_data_offset,
|
||||
recv_data);
|
||||
|
||||
auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
|
||||
std::vector<int32_t> local_expert_acc(num_experts, 0);
|
||||
auto send_token_idx_cpu = at::empty({num_tokens, num_topk}, options_cpu);
|
||||
auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
|
||||
|
||||
auto topk_idx_cpu = topk_idx.to(at::kCPU);
|
||||
auto topk_idx_ptr = topk_idx_cpu.data_ptr<int64_t>();
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
for (int j = 0; j < num_topk; ++j) {
|
||||
int64_t expert_idx = topk_idx_ptr[i * num_topk + j];
|
||||
if (expert_idx >= 0) {
|
||||
int32_t cnt = local_expert_acc[expert_idx];
|
||||
send_token_idx_ptr[i * num_topk + j] = cnt;
|
||||
local_expert_acc[expert_idx]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_BIND_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
|
||||
TORCH_BIND_ASSERT(recv_data.size(0) % num_experts == 0);
|
||||
at::Tensor recv_offset_cpu = at::empty({num_experts}, options_cpu);
|
||||
at::Tensor recv_count_cpu = at::empty({num_experts}, options_cpu);
|
||||
auto recv_data_cpu = recv_data.to(at::kCPU);
|
||||
auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
|
||||
auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
|
||||
auto recv_offset_ptr = recv_offset_cpu.data_ptr<int>();
|
||||
int64_t total_recv_tokens = 0;
|
||||
int64_t num_max_dispatch_tokens_per_rank = 0;
|
||||
std::vector<int64_t> num_recv_tokens_per_expert_list;
|
||||
|
||||
for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
|
||||
int64_t local_expert_recv_tokens = 0;
|
||||
for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
|
||||
int64_t index = local_e * num_ranks + src_rank;
|
||||
int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
|
||||
|
||||
int recv_cnt = recv_data_ptr[pair_idx]; // count from this src_rank for
|
||||
// this global_expert
|
||||
int recv_off = recv_data_ptr[pair_idx + 1]; // offset in that src_rank's window
|
||||
int64_t send_num_tokens = recv_data_ptr[pair_idx + 2]; // all bs from rank
|
||||
|
||||
total_recv_tokens += recv_cnt;
|
||||
recv_count_ptr[index] = total_recv_tokens;
|
||||
recv_offset_ptr[index] = recv_off;
|
||||
num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens);
|
||||
|
||||
local_expert_recv_tokens += recv_cnt;
|
||||
}
|
||||
num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
|
||||
}
|
||||
auto option = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
|
||||
at::Tensor num_recv_tokens_per_expert = torch::from_blob(
|
||||
num_recv_tokens_per_expert_list.data(), {static_cast<int64_t>(num_recv_tokens_per_expert_list.size())}, option)
|
||||
.clone();
|
||||
|
||||
at::Tensor expert_ids = topk_idx.to(at::kInt);
|
||||
int64_t tp_size = 1;
|
||||
int64_t tp_rank = 0;
|
||||
int64_t quant_mode = 0;
|
||||
int64_t global_bs = static_cast<int64_t>(
|
||||
std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
|
||||
|
||||
auto send_token_idx = send_token_idx_cpu.to(x.device());
|
||||
auto recv_offset = recv_offset_cpu.to(x.device());
|
||||
auto recv_count = recv_count_cpu.to(x.device());
|
||||
|
||||
int total_cnt = total_recv_tokens;
|
||||
if (total_cnt == 0) {
|
||||
total_cnt = 1;
|
||||
}
|
||||
auto expandx_out = at::empty({total_cnt, hidden}, x.options());
|
||||
auto dynamic_scales_out = at::empty({total_cnt}, at::dtype(at::kFloat).device(x.device()));
|
||||
auto expand_idx_out = at::empty({total_cnt * 3}, at::dtype(at::kInt).device(x.device()));
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeDispatchNormal,
|
||||
new_x,
|
||||
expert_ids,
|
||||
send_data_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
group_ep_ptr, // commGroup
|
||||
num_ranks, // rankSize
|
||||
rank, // rankId
|
||||
group_ep_ptr,
|
||||
tp_size,
|
||||
tp_rank,
|
||||
num_experts,
|
||||
quant_mode,
|
||||
global_bs,
|
||||
expandx_out,
|
||||
dynamic_scales_out,
|
||||
expand_idx_out);
|
||||
|
||||
// Return values
|
||||
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
|
||||
}
|
||||
|
||||
at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
|
||||
const at::Tensor& src_idx, const at::Tensor& send_head, c10::string_view groupEp,
|
||||
int64_t rank, int64_t num_ranks) {
|
||||
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
char* group_ep_ptr = &group_ep_chrs[0];
|
||||
|
||||
TORCH_BIND_ASSERT(x.dim() == 2 and x.is_contiguous());
|
||||
at::Tensor recv_x = x;
|
||||
|
||||
at::Tensor topk_idx_p = topk_idx;
|
||||
|
||||
auto topk_idx_int32 = topk_idx_p.to(at::kInt);
|
||||
at::Tensor expand_ids = topk_idx_int32;
|
||||
at::Tensor token_src_info = src_idx;
|
||||
at::Tensor ep_send_counts = send_head;
|
||||
auto device = x.device();
|
||||
|
||||
const int num_tokens = topk_idx_p.size(0);
|
||||
const int num_topk = topk_idx_p.size(1);
|
||||
|
||||
int64_t hidden = static_cast<int>(recv_x.size(1));
|
||||
at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
|
||||
int64_t tp_world_size = 1;
|
||||
int64_t tp_rankId = 0;
|
||||
int64_t moe_expert_number = send_head.size(0);
|
||||
int64_t global_bs = topk_idx_p.size(0) * num_ranks;
|
||||
|
||||
// Combine data
|
||||
auto combined_x = torch::empty({topk_weights.size(0), hidden}, x.options());
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeCombineNormal,
|
||||
recv_x,
|
||||
token_src_info,
|
||||
ep_send_counts,
|
||||
topk_weights,
|
||||
tp_send_counts,
|
||||
group_ep_ptr,
|
||||
num_ranks,
|
||||
rank,
|
||||
group_ep_ptr,
|
||||
tp_world_size,
|
||||
tp_rankId,
|
||||
moe_expert_number,
|
||||
global_bs,
|
||||
combined_x);
|
||||
|
||||
return combined_x;
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -921,4 +1162,25 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" int max_output_size, Tensor! out) -> Tensor"
|
||||
);
|
||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||
|
||||
ops.def("get_dispatch_layout(Tensor topk_idx, int num_experts, int "
|
||||
"num_ranks) -> (Tensor num_tokens_per_rank, Tensor "
|
||||
"num_tokens_per_expert, Tensor is_token_in_rank_bool)");
|
||||
ops.impl("get_dispatch_layout", torch::kPrivateUse1,
|
||||
&vllm_ascend::get_dispatch_layout);
|
||||
|
||||
ops.def(
|
||||
"dispatch_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
|
||||
"Tensor num_tokens_per_rank, Tensor is_token_in_rank, Tensor "
|
||||
"num_tokens_per_expert, int num_worst_tokens, str groupEp, int rank, "
|
||||
"int num_ranks) -> (Tensor expandx_out, Tensor expand_idx_out, Tensor "
|
||||
"recv_count, Tensor num_recv_tokens_per_expert)");
|
||||
ops.impl("dispatch_prefill", torch::kPrivateUse1,
|
||||
&vllm_ascend::dispatch_prefill);
|
||||
|
||||
ops.def("combine_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
|
||||
"Tensor src_idx, Tensor send_head, str grouEp, int rank, int "
|
||||
"num_ranks) -> Tensor");
|
||||
ops.impl("combine_prefill", torch::kPrivateUse1,
|
||||
&vllm_ascend::combine_prefill);
|
||||
}
|
||||
|
||||
24
csrc/utils.h
24
csrc/utils.h
@@ -28,4 +28,28 @@
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
|
||||
class TrochBindException : public std::exception
|
||||
{
|
||||
private:
|
||||
std::string message = {};
|
||||
|
||||
public:
|
||||
explicit TrochBindException(const char *name, const char *file, const int line, const std::string &error)
|
||||
{
|
||||
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) +
|
||||
" error message or error code is '" + error + "'";
|
||||
}
|
||||
|
||||
const char *what() const noexcept override
|
||||
{
|
||||
return message.c_str();
|
||||
}
|
||||
};
|
||||
|
||||
#define TORCH_BIND_ASSERT(cond) \
|
||||
; \
|
||||
do { \
|
||||
if (not(cond)) { \
|
||||
throw TrochBindException("Assertion", __FILE__, __LINE__, #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
72
csrc/utils/inc/kernel/comm_args.h
Normal file
72
csrc/utils/inc/kernel/comm_args.h
Normal file
@@ -0,0 +1,72 @@
|
||||
#ifndef COMM_ARGS_H
|
||||
#define COMM_ARGS_H
|
||||
#include <cstdint>
|
||||
|
||||
#define FORCE_INLINE_AICORE __attribute__((always_inline)) inline __aicore__
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace Moe {
|
||||
constexpr int CAM_MAX_RANK_SIZE = 384; // Maximum number of NPU cards supported by the communication library
|
||||
|
||||
constexpr int64_t IPC_BUFF_MAX_SIZE = 100 * 1024 * 1024;
|
||||
constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // First 2MB as flag, then 100MB as data storage
|
||||
constexpr int64_t PING_PONG_SIZE = 2;
|
||||
constexpr int64_t UB_SINGLE_DMA_SIZE_MAX = 190 * 1024;
|
||||
constexpr int64_t SMALL_DATA_SIZE = 1 * 1024 * 1024;
|
||||
constexpr int64_t UB_SINGLE_PING_PONG_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX / 2;
|
||||
constexpr int UB_ALIGN_SIZE = 32;
|
||||
constexpr int64_t MAGIC_ALIGN_COUNT = UB_ALIGN_SIZE / sizeof(int32_t);
|
||||
|
||||
constexpr uint8_t COMM_NUM = 2; // Size of communication domain
|
||||
constexpr uint8_t COMM_EP_IDX = 0;
|
||||
constexpr uint8_t COMM_TP_IDX = 1;
|
||||
|
||||
constexpr int DFX_COUNT = 50;
|
||||
constexpr int64_t WAIT_SUCCESS = 112233445566;
|
||||
constexpr int64_t IPC_CHUNK_FLAG = 0; // Start offset for send recv, chunk flag region
|
||||
constexpr int64_t MAX_WAIT_ROUND_UNIT = 10 * 1000 * 1000; // Threshold for waiting to get Flag under normal conditions within the same SIO
|
||||
|
||||
constexpr static int32_t UB_HEAD_OFFSET = 96;
|
||||
constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + UB_ALIGN_SIZE;
|
||||
constexpr static int64_t UB_FLAG_SIZE = 2 * 1024;
|
||||
constexpr static int64_t MAX_CORE_NUM = 48;
|
||||
constexpr static uint64_t STATE_WIN_OFFSET = 900 * 1024;
|
||||
constexpr static int64_t COMPARE_ALIGN_SIZE = 256;
|
||||
|
||||
constexpr static int64_t UB_SINGLE_TOTAL_SIZE_MAX = 192 * 1024;
|
||||
constexpr static int64_t START_OFFSET_FOR_SHARE = 512;
|
||||
|
||||
enum Op : int {
|
||||
COPYONLY = -1,
|
||||
ADD = 0,
|
||||
MUL = 1,
|
||||
MAX = 2,
|
||||
MIN = 3
|
||||
};
|
||||
|
||||
struct CommArgs {
|
||||
int rank = 0; // attr rank_id, global rank
|
||||
int localRank = -1;
|
||||
int rankSize = 0; // global rank size
|
||||
int localRankSize = -1; // This parameter refers to the number of cards interconnected in fullmesh
|
||||
uint32_t extraFlag = 0; // 32 bit map, the specific meaning of each bit is above in this file
|
||||
int testFlag = 0;
|
||||
GM_ADDR peerMems[CAM_MAX_RANK_SIZE] = {}; // Buffer obtained from initialization, all allreduce is the same parameter
|
||||
/**
|
||||
* @param sendCountMatrix One-dimensional array with a size of rankSize*rankSize
|
||||
* eg: The value of sendCountMatrix[1] corresponds to the [0][1] of the two-dimensional array, indicating the number of data that card 0 needs to send to card 1
|
||||
*/
|
||||
int64_t sendCountMatrix[CAM_MAX_RANK_SIZE * CAM_MAX_RANK_SIZE] = {}; // for all2allvc
|
||||
int64_t sendCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv
|
||||
int64_t sdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv
|
||||
int64_t recvCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv
|
||||
int64_t rdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv
|
||||
int64_t batchSize;
|
||||
int64_t hiddenSize;
|
||||
int64_t topk;
|
||||
int64_t sharedExpertRankNum;
|
||||
int64_t expertNumPerRank;
|
||||
int64_t dfx[DFX_COUNT] = {};
|
||||
};
|
||||
}
|
||||
#endif // COMM_ARGS_H
|
||||
68
csrc/utils/inc/kernel/data_copy.h
Normal file
68
csrc/utils/inc/kernel/data_copy.h
Normal file
@@ -0,0 +1,68 @@
|
||||
#ifndef CAM_DATACOPY_GM2GM_H
|
||||
#define CAM_DATACOPY_GM2GM_H
|
||||
#include <type_traits>
|
||||
#include "comm_args.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace Moe;
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void SetAtomicOpType(int op)
|
||||
{
|
||||
switch (op) {
|
||||
case ADD:
|
||||
AscendC::SetAtomicAdd<T>();
|
||||
break;
|
||||
case MUL:
|
||||
// Ignore setting the atomic register when performing mul
|
||||
break;
|
||||
case MAX:
|
||||
AscendC::SetAtomicMax<T>();
|
||||
break;
|
||||
case MIN:
|
||||
AscendC::SetAtomicMin<T>();
|
||||
break;
|
||||
default:
|
||||
AscendC::SetAtomicNone();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void CpUB2GM(__gm__ T *gmAddr, __ubuf__ T *ubAddr, uint32_t size)
|
||||
{
|
||||
LocalTensor<uint8_t> ubTensor;
|
||||
GlobalTensor<uint8_t> gmTensor;
|
||||
DataCopyExtParams dataCopyParams(1, size, 0, 0, 0);
|
||||
ubTensor.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ubTensor.address_.bufferAddr = reinterpret_cast<uint64_t>(ubAddr);
|
||||
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr));
|
||||
DataCopyPad(gmTensor, ubTensor, dataCopyParams);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void CpGM2UB(__ubuf__ T *ubAddr, __gm__ T *gmAddr, uint32_t size)
|
||||
{
|
||||
LocalTensor<uint8_t> ubTensor;
|
||||
GlobalTensor<uint8_t> gmTensor;
|
||||
DataCopyExtParams dataCopyParams(1, size, 0, 0, 0);
|
||||
ubTensor.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ubTensor.address_.bufferAddr = reinterpret_cast<uint64_t>(ubAddr);
|
||||
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr));
|
||||
DataCopyPadExtParams<uint8_t> padParams;
|
||||
DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, const uint32_t calCount)
|
||||
{
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
TBuffAddr srcAddr, dstAddr;
|
||||
srcAddr.bufferAddr = reinterpret_cast<uint64_t>(src);
|
||||
dstAddr.bufferAddr = reinterpret_cast<uint64_t>(dst);
|
||||
srcTensor.SetAddr(srcAddr);
|
||||
dstTensor.SetAddr(dstAddr);
|
||||
DataCopy(dstTensor, srcTensor, calCount);
|
||||
}
|
||||
|
||||
#endif // CAM_DATACOPY_GM2GM_H
|
||||
199
csrc/utils/inc/kernel/moe_distribute_base.h
Normal file
199
csrc/utils/inc/kernel/moe_distribute_base.h
Normal file
@@ -0,0 +1,199 @@
|
||||
/*!
|
||||
* \file moe_distribute_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef MOE_DISTRIBUTE_BASE_H
|
||||
#define MOE_DISTRIBUTE_BASE_H
|
||||
|
||||
/* system tick: 50MHz */
|
||||
#define CAL_US(tick) (((tick) * 2) / 100)
|
||||
|
||||
/* performance macro */
|
||||
// #define USE_256_TO_1__ // Enable 256 to 1
|
||||
#ifdef USE_256_TO_1__
|
||||
#pragma message("use 256 to 1")
|
||||
#else // 256 to 1 is only used as baseline, not combined with other optimization points
|
||||
#define USE_FOR_OPT__ // Enable loop optimization loop optimization
|
||||
#define DISPATCH_USE_WRITE_SHUFFLE__ // Dispatch uses write shuffle
|
||||
#define USE_TOKEN_COUNT_SPLIT__ // Enable separation of token and count flags token and count flags
|
||||
#define USE_ONE_CORE_WAIT__ // Enable single core wait
|
||||
|
||||
#ifdef USE_ONE_CORE_WAIT__
|
||||
#pragma message("use one core wait")
|
||||
// Enable single core cumsum calculation
|
||||
// #define USE_ONE_CORE_GETCUMSUM__
|
||||
#endif
|
||||
#ifdef USE_FOR_OPT__
|
||||
#pragma message("use for optimization")
|
||||
#define FOR_OPT_MAX_BS__ 64
|
||||
#define FOR_OPT_MAX_MOE_RANK__ 256
|
||||
#endif
|
||||
// #define COMBINE_USE_DYNAMIC_QUANT // Combine quantization is disabled by default
|
||||
#define OPT_RANK_OFFSET 512
|
||||
#define USE_WRITE_SHUFFLE
|
||||
#endif
|
||||
|
||||
constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64;
|
||||
constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19;
|
||||
constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2;
|
||||
constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024;
|
||||
|
||||
struct HcclSignalInfo {
|
||||
uint64_t resId; // EventId when representing event, notifyId when representing notify
|
||||
uint64_t addr;
|
||||
uint32_t devId;
|
||||
uint32_t tsId;
|
||||
uint32_t rankId;
|
||||
uint32_t flag;
|
||||
};
|
||||
|
||||
struct ListCommon {
|
||||
uint64_t nextHost;
|
||||
uint64_t preHost;
|
||||
uint64_t nextDevice;
|
||||
uint64_t preDevice;
|
||||
};
|
||||
|
||||
struct HcclStreamInfo {
|
||||
int32_t streamIds;
|
||||
uint32_t sqIds;
|
||||
uint32_t cqIds; // Record physical cqId
|
||||
uint32_t logicCqids; // Record logical cqId
|
||||
};
|
||||
|
||||
struct LocalResInfoV2 {
|
||||
uint32_t streamNum;
|
||||
uint32_t signalNum;
|
||||
HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM];
|
||||
HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM];
|
||||
HcclStreamInfo mainStreamInfo;
|
||||
HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; // Collective communication AICPU expanded resources
|
||||
ListCommon nextTagRes; // HccltagLocalResV2
|
||||
};
|
||||
|
||||
enum class rtFloatOverflowMode_t {
|
||||
RT_OVERFLOW_MODE_SATURATION = 0,
|
||||
RT_OVERFLOW_MODE_INFNAN,
|
||||
RT_OVERFLOW_MODE_UNDEF,
|
||||
};
|
||||
|
||||
struct AlgoTopoInfo {
|
||||
uint32_t userRank; // Communication domain RankID
|
||||
uint32_t userRankSize; // Number of Ranks in communication domain
|
||||
int32_t deviceLogicId;
|
||||
bool isSingleMeshAggregation;
|
||||
uint32_t deviceNumPerAggregation; // Number of Devices in each Module
|
||||
uint32_t superPodNum; // Total number of super nodes in cluster
|
||||
uint32_t devicePhyId;
|
||||
uint32_t topoType; // TopoType
|
||||
uint32_t deviceType;
|
||||
uint32_t serverNum;
|
||||
uint32_t meshAggregationRankSize;
|
||||
uint32_t multiModuleDiffDeviceNumMode;
|
||||
uint32_t multiSuperPodDiffServerNumMode;
|
||||
uint32_t realUserRank;
|
||||
bool isDiffDeviceModule;
|
||||
bool isDiffDeviceType;
|
||||
uint32_t gcdDeviceNumPerAggregation;
|
||||
uint32_t moduleNum;
|
||||
uint32_t isUsedRdmaRankPairNum;
|
||||
uint64_t isUsedRdmaRankPair;
|
||||
uint32_t pairLinkCounterNum;
|
||||
uint64_t pairLinkCounter;
|
||||
uint32_t nicNum;
|
||||
uint64_t nicList; // Pointer to niclist array
|
||||
uint64_t complanRankLength; // Bytes occupied by complanRank
|
||||
uint64_t complanRank; // Pointer
|
||||
uint64_t bridgeRankNum; // Number of bridgeRank entries
|
||||
uint64_t bridgeRank; // Pointer
|
||||
uint64_t serverAndsuperPodRankLength; // Bytes occupied by serverAndsuperPodRank
|
||||
uint64_t serverAndsuperPodRank; // Pointer
|
||||
};
|
||||
|
||||
struct HcclOpConfig {
|
||||
uint8_t deterministic; // Deterministic computation switch
|
||||
uint8_t retryEnable; // Whether to retry execution
|
||||
uint8_t highPerfEnable;
|
||||
uint8_t padding[5]; // Size needs 64-byte alignment, reduce padding when adding parameters in future
|
||||
uint8_t linkTimeOut[8]; // Send timeout duration
|
||||
uint64_t notifyWaitTime; // Timeout duration, same as HCCL_EXEC_TIMEOUTas HCCL_EXEC_TIMEOUT
|
||||
uint32_t retryHoldTime;
|
||||
uint32_t retryIntervalTime;
|
||||
bool interHccsDisable = false; // Enable RDMA switch
|
||||
rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF;
|
||||
uint32_t multiQpThreshold = 512; // Minimum data amount threshold for each QP in multi-QP mode
|
||||
};
|
||||
|
||||
struct HcclMC2WorkSpace {
|
||||
uint64_t workSpace;
|
||||
uint64_t workSpaceSize;
|
||||
};
|
||||
|
||||
struct RemoteResPtr {
|
||||
uint64_t nextHostPtr;
|
||||
uint64_t nextDevicePtr;
|
||||
};
|
||||
|
||||
struct HDCommunicateParams {
|
||||
uint64_t hostAddr { 0 };
|
||||
uint64_t deviceAddr { 0 };
|
||||
uint64_t readCacheAddr { 0 };
|
||||
uint32_t devMemSize{ 0 };
|
||||
uint32_t buffLen{ 0 };
|
||||
uint32_t flag{ 0 };
|
||||
};
|
||||
|
||||
struct HcclRankRelationResV2 {
|
||||
uint32_t remoteUsrRankId;
|
||||
uint32_t remoteWorldRank;
|
||||
uint64_t windowsIn;
|
||||
uint64_t windowsOut;
|
||||
uint64_t windowsExp;
|
||||
ListCommon nextTagRes;
|
||||
};
|
||||
|
||||
struct HcclOpResParam {
|
||||
// Local resources
|
||||
HcclMC2WorkSpace mc2WorkSpace;
|
||||
uint32_t localUsrRankId; // usrrankid
|
||||
uint32_t rankSize; // Total number of ranks in communication domain
|
||||
uint64_t winSize; // Size of each window, may be 0 for static graphs, may be non-zero if dynamic graphs exist in communication domain
|
||||
uint64_t localWindowsIn; // All F means invalid value
|
||||
uint64_t localWindowsOut; // All F means invalid value
|
||||
char hcomId[128];
|
||||
// AICore identifies remote window
|
||||
uint64_t winExpSize;
|
||||
uint64_t localWindowsExp;
|
||||
uint32_t rWinStart; // Start position for HcclRankRelationRes
|
||||
uint32_t rWinOffset; // Size of HcclRemoteRes
|
||||
uint64_t version;
|
||||
LocalResInfoV2 localRes;
|
||||
AlgoTopoInfo topoInfo;
|
||||
|
||||
// External configuration parameters
|
||||
HcclOpConfig config;
|
||||
uint64_t hostStateInfo;
|
||||
uint64_t aicpuStateInfo;
|
||||
uint64_t lockAddr;
|
||||
uint32_t rsv[16];
|
||||
uint32_t notifysize; // Used in RDMA scenarios, 4B for 910B/910_93, 8B for other chips
|
||||
uint32_t remoteResNum; // Valid remoteResNum
|
||||
RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; // Array pointer, points to HcclRankRelationResV2, index is remoteUserRankId
|
||||
|
||||
// communicate retry
|
||||
HDCommunicateParams kfcControlTransferH2DParams;
|
||||
HDCommunicateParams kfcStatusTransferD2HParams;
|
||||
uint64_t tinyMem; // for all2all
|
||||
uint64_t tinyMemSize;
|
||||
// Used in zero-copy scenarios
|
||||
uint64_t zeroCopyHeadPtr;
|
||||
uint64_t zeroCopyTailPtr;
|
||||
uint64_t zeroCopyRingBuffer;
|
||||
uint64_t zeroCopyIpcPtrs[16]; // Save input/output memory addresses of each peer during collective communication
|
||||
uint32_t zeroCopyDevicePhyId[16]; // Save physical card ID corresponding to each rank
|
||||
|
||||
bool utraceStatusFlag;
|
||||
};
|
||||
|
||||
#endif // MOE_DISTRIBUTE_BASE_H
|
||||
426
csrc/utils/inc/kernel/sync_collectives.h
Normal file
426
csrc/utils/inc/kernel/sync_collectives.h
Normal file
@@ -0,0 +1,426 @@
|
||||
#ifndef SYNC_COLLECTIVES_H
|
||||
#define SYNC_COLLECTIVES_H
|
||||
|
||||
#include "comm_args.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace Moe;
|
||||
|
||||
// Synchronization flag occupies length
|
||||
constexpr int64_t FLAG_UNIT_INT_NUM = 4;
|
||||
// Memory size occupied by each synchronization unit (Bytes)
|
||||
constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t);
|
||||
// High-order offset when using magic as a comparison value
|
||||
constexpr int64_t MAGIC_OFFSET = 32;
|
||||
constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1);
|
||||
|
||||
class SyncCollectives {
|
||||
public:
|
||||
__aicore__ inline SyncCollectives() {}
|
||||
|
||||
__aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, TBuf<QuePosition::VECCALC> &tBuf)
|
||||
{
|
||||
this->rank = rank;
|
||||
this->rankSize = rankSize;
|
||||
this->shareAddrs = shareAddrs;
|
||||
this->blockIdx = GetBlockIdx();
|
||||
this->blockNum = GetBlockNum();
|
||||
// Length of a single indicator segment
|
||||
segmentCount = GetBlockNum() * FLAG_UNIT_INT_NUM;
|
||||
// Initialize the intra-card/inter-card synchronization address corresponding to the current core.
|
||||
localSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]);
|
||||
basicSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + GetBlockIdx() * FLAG_UNIT_INT_NUM;
|
||||
blockOuterSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + segmentCount + GetBlockIdx() * FLAG_UNIT_INT_NUM;
|
||||
this->tBuf = tBuf;
|
||||
}
|
||||
|
||||
__aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID)
|
||||
{
|
||||
int64_t v = MergeMagicWithValue(magic, value);
|
||||
SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the flag for the specified eventID of the designated card, with the value being a combination of magic and value.
|
||||
* @param magic The operator batch, which will be combined into the high 32 bits of the flag value to be set.
|
||||
* @param value The specific value to be set, which will be the low 32 bits of the flag value to be set.
|
||||
* @param eventID Physically, it is an offset from the shared memory base address (requires scaling, not an absolute value).
|
||||
* @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs structure, not a global or local id.
|
||||
* (Local is not applicable in the 91093 scenario, and global is not applicable in the 910B multi-machine scenario.)
|
||||
*/
|
||||
__aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank)
|
||||
{
|
||||
int64_t v = MergeMagicWithValue(magic, value);
|
||||
SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v);
|
||||
}
|
||||
|
||||
__aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId)
|
||||
{
|
||||
return (blockMultiplier * blockNum) + targetCoreId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Wait for the flag of the specified eventID on the specified card to become a value
|
||||
* composed of the combination of magic and value.
|
||||
* @param magic The operator batch, which will be combined into the high 32 bits of the flag
|
||||
* value to be wait.
|
||||
* @param value The specific value to be wait, which will be the low 32 bits of the flag
|
||||
* value to be wait.
|
||||
* @param eventID Physically, it is an offset from the shared memory base address (requires
|
||||
* scaling, not an absolute value).
|
||||
* @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs
|
||||
* structure, not a global or local id. (Local is not applicable in the 91093
|
||||
* scenario, and global is not applicable in the 910B multi-machine scenario.)
|
||||
*/
|
||||
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank)
|
||||
{
|
||||
int64_t v = MergeMagicWithValue(magic, value);
|
||||
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v);
|
||||
}
|
||||
|
||||
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID)
|
||||
{
|
||||
int64_t v = MergeMagicWithValue(magic, value);
|
||||
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Wait for the flags starting from the specified eventID on the specified card to become
|
||||
* a value composed of the combination of magic and value.<br>
|
||||
* Note: [eventID, eventID + flagNum)
|
||||
*/
|
||||
__aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum)
|
||||
{
|
||||
int64_t v = MergeMagicWithValue(magic, value);
|
||||
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v);
|
||||
}
|
||||
|
||||
// Set inner-card synchronization flag (memory A)
|
||||
__aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
SetFlag(basicSyncAddr, value);
|
||||
}
|
||||
|
||||
__aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
SetFlag((__gm__ int64_t*)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value);
|
||||
}
|
||||
|
||||
// Wait for a single inner-card synchronization flag (memory A)
|
||||
__aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value);
|
||||
}
|
||||
|
||||
// Wait for all inner-card synchronization flags within the entire rank (memory A)
|
||||
__aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
WaitOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value);
|
||||
}
|
||||
|
||||
// Check all inner-card synchronization flags within the entire rank (memory A)
|
||||
__aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
return CheckOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value);
|
||||
}
|
||||
|
||||
// Set inter-card synchronization flag (memory B)
|
||||
__aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
SetFlag(blockOuterSyncAddr, value);
|
||||
}
|
||||
|
||||
__aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock)
|
||||
{
|
||||
__gm__ int64_t* flagAddr = GetOuterFlagAddr(setRank, setBlock);
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
SetFlag(flagAddr, value);
|
||||
}
|
||||
|
||||
// Wait for a single inter-card synchronization flag (memory B)
|
||||
__aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
__gm__ int64_t* flagAddr = GetOuterFlagAddr(waitRank, waitBlock);
|
||||
WaitOneRankPartFlag(flagAddr, 1, value);
|
||||
}
|
||||
|
||||
// Wait for all inter-card synchronization flags within the entire rank (memory B)
|
||||
__aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t eventID, int64_t rank)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
__gm__ int64_t* flagAddr;
|
||||
flagAddr = GetOuterFlagAddr(rank, 0);
|
||||
WaitOneRankPartFlag(flagAddr, blockNum, value);
|
||||
}
|
||||
|
||||
// Wait for flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B)
|
||||
__aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
__gm__ int64_t* flagAddr;
|
||||
int waitRank;
|
||||
for (auto r = 0; r < rankSize; ++r) {
|
||||
waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores
|
||||
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
|
||||
WaitOneRankPartFlag(flagAddr, flagNum, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Check flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B)
|
||||
__aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock,
|
||||
int64_t flagNum)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
__gm__ int64_t* flagAddr;
|
||||
int waitRank;
|
||||
for (auto r = 0; r < rankSize; ++r) {
|
||||
waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from concurrent copying by multiple cores
|
||||
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
|
||||
if (!CheckOneRankPartFlag(flagAddr, flagNum, value)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Wait for all inter-card synchronization flags for all ranks, full rank synchronization (memory B)
|
||||
__aicore__ inline void WaitAllRankOuterFlag(int32_t magic, int32_t eventID)
|
||||
{
|
||||
WaitAllRankPartOuterFlag(magic, eventID, 0, blockNum);
|
||||
}
|
||||
|
||||
// Check all inter-card synchronization flags for all ranks, full rank synchronization (memory B)
|
||||
__aicore__ inline bool CheckAllRankOuterFlag(int32_t magic, int32_t eventID)
|
||||
{
|
||||
return CheckAllRankPartOuterFlag(magic, eventID, 0, blockNum);
|
||||
}
|
||||
|
||||
// Low-level interface, set synchronization flag
|
||||
__aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue)
|
||||
{
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
GlobalTensor<int64_t> globalSet;
|
||||
globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM);
|
||||
LocalTensor<int64_t> localSet = tBuf.GetWithOffset<int64_t>(1, 0);
|
||||
localSet.SetValue(0, setValue);
|
||||
|
||||
// Copy global synchronization flag to local
|
||||
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0); // Wait for SetValue to complete
|
||||
DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM);
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0); // Wait for UB->GM to complete
|
||||
}
|
||||
|
||||
// Low-level interface, wait for synchronization flag
|
||||
__aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue)
|
||||
{
|
||||
WaitOneRankPartFlag(waitAddr, 1, waitValue);
|
||||
}
|
||||
|
||||
// Read a flag, return an immediate number
|
||||
__aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr)
|
||||
{
|
||||
GlobalTensor<int64_t> globalWait;
|
||||
globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM);
|
||||
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(1, 0);
|
||||
// Copy global to local
|
||||
DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
|
||||
|
||||
int64_t res = localWait.GetValue(0);
|
||||
return res;
|
||||
}
|
||||
|
||||
// Get multiple consecutive synchronization flags within a single card
|
||||
__aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank,
|
||||
int64_t startBlock, int64_t flagNum)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventID);
|
||||
__gm__ int64_t* flagAddr;
|
||||
flagAddr = GetOuterFlagAddr(waitRank, startBlock);
|
||||
WaitOneRankPartFlag(flagAddr, flagNum, value);
|
||||
}
|
||||
|
||||
// Get synchronization flag within a single card (memory A)
|
||||
__aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock)
|
||||
{
|
||||
return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM);
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock)
|
||||
{
|
||||
return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM);
|
||||
}
|
||||
|
||||
// In the rank Chunk Flag area, return success if the destRank chunk Flag value is 0, otherwise fail
|
||||
__aicore__ inline int64_t GetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, 0);
|
||||
int64_t status = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) +
|
||||
IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value, timeout);
|
||||
return status;
|
||||
}
|
||||
|
||||
// Set the destRank chunk Flag value in the rank Chunk Flag area to value
|
||||
__aicore__ inline void SetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t eventId)
|
||||
{
|
||||
int64_t value = MergeMagicWithValue(magic, eventId);
|
||||
SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value);
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t GetChunkRecvLen(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout)
|
||||
{
|
||||
int64_t len = GetChunkFlagValue((__gm__ int64_t*)(shareAddrs[rank]) + IPC_CHUNK_FLAG +
|
||||
destRank * FLAG_UNIT_INT_NUM, 0, timeout, true, magic);
|
||||
return len;
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value)
|
||||
{
|
||||
// Merge magic as the high bits and eventID as the low bits into a value for comparison
|
||||
return (static_cast<int64_t>(static_cast<uint32_t>(magic)) << MAGIC_OFFSET) | static_cast<int64_t>(value);
|
||||
}
|
||||
|
||||
__aicore__ inline __gm__ int64_t* GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock)
|
||||
{
|
||||
return (__gm__ int64_t*)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM;
|
||||
}
|
||||
|
||||
__aicore__ inline __gm__ int64_t* GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock)
|
||||
{
|
||||
return (__gm__ int64_t*)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM;
|
||||
}
|
||||
|
||||
// Wait for a part of synchronization flags within a rank
|
||||
__aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
|
||||
{
|
||||
GlobalTensor<int64_t> globalWait;
|
||||
globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
|
||||
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);
|
||||
bool isSync = true;
|
||||
int64_t checkedFlagNum = 0;
|
||||
do {
|
||||
// Copy global synchronization flags to local
|
||||
DataCopy(localWait, globalWait[checkedFlagNum * FLAG_UNIT_INT_NUM],
|
||||
(flagNum - checkedFlagNum) * FLAG_UNIT_INT_NUM);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
|
||||
|
||||
// Check if the synchronization flags are equal to checkValue
|
||||
isSync = true;
|
||||
int64_t remainToCheck = flagNum - checkedFlagNum;
|
||||
for (auto i = 0; i < remainToCheck; ++i) {
|
||||
// Continue waiting if any core has not reached the checkValue phase
|
||||
int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
|
||||
if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) {
|
||||
isSync = false;
|
||||
checkedFlagNum += i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} while (!isSync);
|
||||
}
|
||||
|
||||
// Wait for all synchronization flags within a rank
|
||||
__aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
|
||||
{
|
||||
WaitOneRankPartFlag(waitAddr, blockNum, checkValue);
|
||||
}
|
||||
|
||||
// Check partial synchronization flags within a rank, copy only once
|
||||
__aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
|
||||
{
|
||||
GlobalTensor<int64_t> globalWait;
|
||||
globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
|
||||
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);
|
||||
// Copy global synchronization flags to local
|
||||
DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
|
||||
// Check if the synchronization flags are equal to checkValue
|
||||
bool isSync = true;
|
||||
for (auto i = 0; i < flagNum; ++i) {
|
||||
// Continue waiting if any core has not reached the checkValue phase
|
||||
int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
|
||||
if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) {
|
||||
isSync = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return isSync;
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t GetChunkFlagValue(__gm__ int64_t* waitAddr, int64_t checkValue, int64_t timeout,
|
||||
bool checkNonZero = false, int64_t magic = 0)
|
||||
{
|
||||
GlobalTensor<int64_t> globalWait;
|
||||
globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM);
|
||||
LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(FLAG_UNIT_INT_NUM, 0);
|
||||
bool isSync = true;
|
||||
|
||||
int64_t waitTimes = 0;
|
||||
int64_t v = 0;
|
||||
|
||||
do {
|
||||
// Copy global sync flag to local
|
||||
DataCopy(localWait, globalWait[0], FLAG_UNIT_INT_NUM);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // Wait for GM->UB
|
||||
|
||||
isSync = true;
|
||||
v = localWait.GetValue(0);
|
||||
if (checkNonZero) {
|
||||
// Non-zero check mode
|
||||
if (((v & MAGIC_MASK) == (static_cast<int64_t>(magic) << MAGIC_OFFSET)) && (v & 0xFFFFFFFF)) {
|
||||
return v & 0xFFFFFFFF; // Return lower 32 bits when non-zero
|
||||
}
|
||||
} else {
|
||||
// Exact value check mode
|
||||
if (v == checkValue) {
|
||||
return WAIT_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
isSync = false;
|
||||
waitTimes++;
|
||||
|
||||
if (timeout > INT64_MAX / MAX_WAIT_ROUND_UNIT || waitTimes >= (timeout * MAX_WAIT_ROUND_UNIT)) {
|
||||
isSync = true;
|
||||
return v; // Return the read flag value
|
||||
}
|
||||
} while (!isSync);
|
||||
|
||||
return checkNonZero ? 0 : v;
|
||||
}
|
||||
|
||||
// Check all sync flags within a rank, copy only once
|
||||
__aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
|
||||
{
|
||||
return CheckOneRankPartFlag(waitAddr, blockNum, checkValue);
|
||||
}
|
||||
int rank;
|
||||
int rankSize;
|
||||
int blockIdx;
|
||||
int blockNum;
|
||||
GM_ADDR *shareAddrs;
|
||||
int64_t segmentCount; // Length of a single sync flag segment (count in int64_t)
|
||||
__gm__ int64_t* localSyncAddr;
|
||||
__gm__ int64_t* basicSyncAddr; // Intra-card sync flag address for the current block
|
||||
__gm__ int64_t* blockOuterSyncAddr; // Inter-card sync flag address for the current block
|
||||
TBuf<QuePosition::VECCALC> tBuf;
|
||||
};
|
||||
|
||||
#endif // SYNC_COLLECTIVES_H
|
||||
Reference in New Issue
Block a user