Revert "[Kernel] add custom moe ops for prefill" (#4806)

Reverts vllm-project/vllm-ascend#4194 as it broke CI in
https://github.com/vllm-project/vllm-ascend/actions/runs/20030369087/job/57437687382?pr=4791

Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Mengqing Cao
2025-12-08 23:20:32 +08:00
committed by GitHub
parent 432b861cae
commit 7e70da9fb7
39 changed files with 2 additions and 5562 deletions

View File

@@ -1,49 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
add_ops_compile_options(
OP_NAME MoeCombineNormal
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
)
target_sources(op_host_aclnnInner PRIVATE
moe_combine_normal.cpp
)
target_sources(opapi PRIVATE
aclnn_moe_combine_normal.cpp
)
if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
aclnn_moe_combine_normal.cpp
)
target_sources(aclnn_ops_infer PRIVATE
aclnn_moe_combine_normal.cpp
)
endif ()
target_sources(optiling PRIVATE
moe_combine_normal_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE)
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_combine_normal.h")
install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -1,77 +0,0 @@
#include <string.h>
#include "graph/types.h"
#include "aclnn_moe_combine_normal.h"
enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
NNOPBASE_HCCL_SERVER_TYPE_MTE,
NNOPBASE_HCCL_SERVER_TYPE_END
};
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
#ifdef __cplusplus
extern "C" {
#endif
extern aclnnStatus aclnnInnerMoeCombineNormalGetWorkspaceSize(
const aclTensor *recvX,
const aclTensor *tokenSrcInfo,
const aclTensor *epRecvCounts,
const aclTensor *recvTopkWeights,
const aclTensor *tpRecvCountsOptional,
char *epGroupName,
int64_t epWorldSize,
int64_t epRankId,
char *tpGroupNameOptional,
int64_t tpWorldSize,
int64_t tpRankId,
int64_t moeExpertNum,
int64_t globalBs,
const aclTensor *out,
uint64_t *workspaceSize,
aclOpExecutor **executor);
extern aclnnStatus aclnnInnerMoeCombineNormal(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize(
const aclTensor *recvX,
const aclTensor *tokenSrcInfo,
const aclTensor *epRecvCounts,
const aclTensor *recvTopkWeights,
const aclTensor *tpRecvCountsOptional,
char *epGroupName,
int64_t epWorldSize,
int64_t epRankId,
char *tpGroupNameOptional,
int64_t tpWorldSize,
int64_t tpRankId,
int64_t moeExpertNum,
int64_t globalBs,
const aclTensor *out,
uint64_t *workspaceSize,
aclOpExecutor **executor)
{
return aclnnInnerMoeCombineNormalGetWorkspaceSize(recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights,
tpRecvCountsOptional, epGroupName, epWorldSize, epRankId,
tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum,
globalBs, out, workspaceSize, executor);
}
aclnnStatus aclnnMoeCombineNormal(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
{
if (NnopbaseSetHcclServerType) {
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
}
return aclnnInnerMoeCombineNormal(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,62 +0,0 @@
#ifndef ACLNN_MOE_COMBINE_NORMAL_H_
#define ACLNN_MOE_COMBINE_NORMAL_H_
#include "aclnn/acl_meta.h"
#ifdef __cplusplus
extern "C" {
#endif
/* funtion: aclnnMoeCombineGetWorkspaceSize
* recvX : required
* tokenSrcInfo : required
* epRecvCounts : required
* recvTopkWeights : required
* tpRecvCountsOptional : required
* epGroupName : optional
* epWorldSize : required
* epRankId : required
* tpGroupNameOptional : required
* tpWorldSize : optional
* tpRankId : optional
* moeExpertNum : optional
* globalBs : optional
* out : required
* workspaceSize : size of workspace(output).
* executor : executor context(output).
*/
__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormalGetWorkspaceSize(
const aclTensor *recvX,
const aclTensor *tokenSrcInfo,
const aclTensor *epRecvCounts,
const aclTensor *recvTopkWeights,
const aclTensor *tpRecvCountsOptional,
char *epGroupName,
int64_t epWorldSize,
int64_t epRankId,
char *tpGroupNameOptional,
int64_t tpWorldSize,
int64_t tpRankId,
int64_t moeExpertNum,
int64_t globalBs,
const aclTensor *out,
uint64_t *workspaceSize,
aclOpExecutor **executor);
/* funtion: aclnnMoeCombine
* workspace : workspace memory addr(input).
* workspaceSize : size of workspace(input).
* executor : executor context(input).
* stream : acl stream.
*/
__attribute__((visibility("default"))) aclnnStatus aclnnMoeCombineNormal(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,71 +0,0 @@
#include "register/op_def_registry.h"
namespace ops {
class MoeCombineNormal : public OpDef {
public:
explicit MoeCombineNormal(const char* name) : OpDef(name) {
this->Input("recv_x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("token_src_info")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("ep_recv_counts")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("recv_topk_weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("tp_recv_counts")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Output("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("ep_group_name").AttrType(REQUIRED).String();
this->Attr("ep_world_size").AttrType(REQUIRED).Int();
this->Attr("ep_rank_id").AttrType(REQUIRED).Int();
this->Attr("tp_group_name").AttrType(OPTIONAL).String("");
this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0);
this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0);
this->Attr("moe_expert_num").AttrType(REQUIRED).Int();
this->Attr("global_bs").AttrType(OPTIONAL).Int(0);
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
.ExtendCfgInfo("jitCompile.flag", "static_true")
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
this->AICore().AddConfig("ascend910_93", aicore_config);
this->MC2().HcclGroup({"ep_group_name", "tp_group_name"});
}
};
OP_ADD(MoeCombineNormal);
} // namespace ops

View File

@@ -1,546 +0,0 @@
#include <queue>
#include <vector>
#include <dlfcn.h>
#include <fcntl.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cmath>
#include <cstdint>
#include <string>
#include <type_traits>
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "log/ops_log.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "../op_kernel/moe_combine_normal_tiling.h"
using namespace AscendC;
using namespace ge;
namespace {
class Mc2TilingUtils {
public:
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
static uint64_t GetMaxWindowSize()
{
uint16_t defaultWindowSize = 200;
if (getenv(HCCL_BUFFSIZE) == nullptr) {
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
} else {
try {
std::string envStr(getenv(HCCL_BUFFSIZE));
defaultWindowSize = std::stoi(envStr);
} catch (...) {
OPS_LOG_E("", "Unknown Exception encountered when parser env HCCL_BUFFERSIZE");
}
}
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
return maxWindowSize;
}
};
constexpr uint32_t RECV_X_INDEX = 0;
constexpr uint32_t TOKEN_SRC_INFO_INDEX = 1;
constexpr uint32_t EP_RECV_COUNTS_INDEX = 2;
constexpr uint32_t TOPK_WEIGHTS_INDEX = 3;
constexpr uint32_t TP_RECV_COUNTS_INDEX = 4;
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_GROUP_TP_INDEX = 3;
constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4;
constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
constexpr uint32_t TWO_DIMS = 2U;
constexpr uint32_t ONE_DIM = 1U;
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll
constexpr uint32_t OP_TYPE_REDUCE_SCATTER = 7U; // numeric representation of ReduceScatter
constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
constexpr int64_t BS_UPPER_BOUND = 8000;
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes
constexpr int64_t MOE_EXPERT_MAX_NUM = 512;
constexpr int64_t K_MAX = 16;
constexpr int64_t H_MIN = 1024;
constexpr int64_t H_MAX = 7168;
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
constexpr uint64_t TRIPLE = 3;
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
constexpr uint64_t SCALE_RECV_IDX_BUFFER = 44UL; // scale32B + 3*4 src info
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL;
constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
constexpr uint64_t UB_ALIGN = 32UL;
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
enum class CommQuantMode : int32_t {
NON_QUANT = 0,
INT12_QUANT = 1,
INT8_QUANT = 2
};
using CommQuantModeType = std::underlying_type<CommQuantMode>;
}
namespace optiling {
// Specific to A3
static void PrintTilingDataInfo(const char *nodeName, MoeCombineNormalTilingData& tilingData)
{
OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeCombineNormalInfo.epWorldSize);
OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeCombineNormalInfo.tpWorldSize);
OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeCombineNormalInfo.epRankId);
OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeCombineNormalInfo.tpRankId);
OPS_LOG_D(nodeName, "expertShardType is %u.", tilingData.moeCombineNormalInfo.expertShardType);
OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeCombineNormalInfo.moeExpertNum);
OPS_LOG_D(nodeName, "moeExpertPerRankNum is %u.", tilingData.moeCombineNormalInfo.moeExpertPerRankNum);
OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeCombineNormalInfo.globalBs);
OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeCombineNormalInfo.bs);
OPS_LOG_D(nodeName, "k is %u.", tilingData.moeCombineNormalInfo.k);
OPS_LOG_D(nodeName, "h is %u.", tilingData.moeCombineNormalInfo.h);
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeCombineNormalInfo.aivNum);
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeCombineNormalInfo.totalUbSize);
OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeCombineNormalInfo.totalWinSize);
}
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
const char *nodeName, std::string &groupEp, std::string &groupTp)
{
auto attrs = context->GetAttrs();
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return ge::GRAPH_FAILED);
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
auto groupTpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_TP_INDEX));
auto epWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_WORLD_SIZE_INDEX);
auto tpWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_WORLD_SIZE_INDEX);
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
auto tpRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_RANK_ID_INDEX);
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
// Check for null
OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
(strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), OPS_LOG_E(nodeName, "groupEp is invalid."),
return ge::GRAPH_FAILED);
OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSize is null."), return ge::GRAPH_FAILED);
OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSize is null."), return ge::GRAPH_FAILED);
OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankId is null."), return ge::GRAPH_FAILED);
OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankId is null."), return ge::GRAPH_FAILED);
OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNum is null."), return ge::GRAPH_FAILED);
// Check if it meets uint32_t and other constraints
int64_t moeExpertNum = *moeExpertNumPtr;
int64_t epWorldSize = *epWorldSizePtr;
OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE),
OPS_LOG_E(nodeName, "epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.",
MIN_EP_WORLD_SIZE, MAX_EP_WORLD_SIZE, epWorldSize), return ge::GRAPH_FAILED);
OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE),
OPS_LOG_E(nodeName, "tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.",
MAX_TP_WORLD_SIZE, *tpWorldSizePtr), return ge::GRAPH_FAILED);
OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize),
OPS_LOG_E(nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.",
epWorldSize, *epRankIdPtr), return ge::GRAPH_FAILED);
if (*tpWorldSizePtr > 1) {
OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr),
OPS_LOG_E(nodeName, "tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.",
*tpWorldSizePtr, *tpRankIdPtr), return ge::GRAPH_FAILED);
OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
(strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
OPS_LOG_E(nodeName, "groupTpPtr is null."), return ge::GRAPH_FAILED);
groupTp = std::string(groupTpPtr);
} else {
OPS_CHECK(*tpRankIdPtr != 0,
OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr),
return ge::GRAPH_FAILED);
}
OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM),
OPS_LOG_E(nodeName, "moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.",
MOE_EXPERT_MAX_NUM, moeExpertNum), return ge::GRAPH_FAILED);
int64_t moePerRankNum = moeExpertNum / epWorldSize;
int64_t curDispatchStatusNum = moePerRankNum * epWorldSize;
OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM),
OPS_LOG_E(nodeName, "The moe experts num must meet the conditions,"
" (moeExpertNum / epWorldSize) * epWorldSize <= 1280, but cur is %ld.",
curDispatchStatusNum), return ge::GRAPH_FAILED);
groupEp = std::string(groupEpPtr);
tilingData.moeCombineNormalInfo.epWorldSize = static_cast<uint32_t>(epWorldSize);
tilingData.moeCombineNormalInfo.tpWorldSize = static_cast<uint32_t>(*tpWorldSizePtr);
tilingData.moeCombineNormalInfo.epRankId = static_cast<uint32_t>(*epRankIdPtr);
tilingData.moeCombineNormalInfo.tpRankId = static_cast<uint32_t>(*tpRankIdPtr);
tilingData.moeCombineNormalInfo.moeExpertNum = static_cast<uint32_t>(moeExpertNum);
return ge::GRAPH_SUCCESS;
}
static bool CheckInputTensorDim(gert::TilingContext *context, const char *nodeName)
{
const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX);
OPS_CHECK(recvXStorageShape == nullptr, OPS_LOG_E(nodeName, "recvX is null."), return false);
OPS_CHECK(recvXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "recvX must be 2-dimension, but got %lu dim",
recvXStorageShape->GetStorageShape().GetDimNum()), return false);
OPS_LOG_D(nodeName, "recvX dim0 = %ld", recvXStorageShape->GetStorageShape().GetDim(0));
OPS_LOG_D(nodeName, "recvX dim1 = %ld", recvXStorageShape->GetStorageShape().GetDim(1));
const gert::StorageShape *tokenSrcInfoStorageShape = context->GetInputShape(TOKEN_SRC_INFO_INDEX);
OPS_CHECK(tokenSrcInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoForCombine is null."), return false);
OPS_CHECK(tokenSrcInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
OPS_LOG_E(nodeName, "tokenSrcInfoForCombine must be 1-dimension, but got %lu dim",
tokenSrcInfoStorageShape->GetStorageShape().GetDimNum()), return false);
OPS_LOG_D(nodeName, "tokenSrcInfoForCombine dim0 = %ld", tokenSrcInfoStorageShape->GetStorageShape().GetDim(0));
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
OPS_CHECK(topkWeightsStorageShape == nullptr, OPS_LOG_E(nodeName, "topkWeights is null."), return false);
OPS_CHECK(topkWeightsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "topkWeights must be 2-dimension, but got %lu dim",
topkWeightsStorageShape->GetStorageShape().GetDimNum()), return false);
OPS_LOG_D(nodeName, "topkWeights dim0 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(0));
OPS_LOG_D(nodeName, "topkWeights dim1 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(1));
return true;
}
static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char *nodeName)
{
const gert::StorageShape *tpRecvCountsStorageShape = context->GetOptionalInputShape(TP_RECV_COUNTS_INDEX);
OPS_CHECK(tpRecvCountsStorageShape == nullptr, OPS_LOG_E(nodeName, "tpRecvCounts is null."), return false);
OPS_CHECK(tpRecvCountsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
OPS_LOG_E(nodeName, "tpRecvCounts must be 1-dimension, but got %lu dim",
tpRecvCountsStorageShape->GetStorageShape().GetDimNum()), return false);
OPS_LOG_D(nodeName, "tpRecvCounts dim0 = %ld", tpRecvCountsStorageShape->GetStorageShape().GetDim(0));
return true;
}
static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName)
{
const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x is null."), return false);
OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "x must be 2-dimension, but got %lu dim", xStorageShape->GetStorageShape().GetDimNum()),
return false);
OPS_LOG_D(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0));
OPS_LOG_D(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1));
return true;
}
static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName)
{
OPS_CHECK(!CheckInputTensorDim(context, nodeName),
OPS_LOG_E(nodeName, "param shape of input tensor is invalid"), return false);
OPS_CHECK(!CheckOptionalInputTensorDim(context, nodeName),
OPS_LOG_E(nodeName, "param shape of optional input tensor is invalid"), return false);
OPS_CHECK(!CheckOutputTensorDim(context, nodeName),
OPS_LOG_E(nodeName, "param shape of output tensor is invalid"), return false);
return true;
}
// Validate data type
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName)
{
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false);
OPS_CHECK((recvXDesc->GetDataType() != ge::DT_BF16) && (recvXDesc->GetDataType() != ge::DT_FLOAT16),
OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is "
), return false);
auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX);
OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false);
OPS_CHECK((tokenSrcInfoDesc->GetDataType() != ge::DT_INT32), OPS_LOG_E(nodeName, "tokenSrcInfoForCombine dataType is invalid,"
" dataType should be int32, but is"), return false);
auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX);
OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false);
OPS_CHECK((tpRecvCountsDesc->GetDataType() != ge::DT_INT32),
OPS_LOG_E(nodeName, "tpRecvCounts dataType is invalid, dataType should be int32, but is "), return false);
auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX);
OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false);
OPS_CHECK((topkWeightsDesc->GetDataType() != ge::DT_FLOAT),
OPS_LOG_E(nodeName, "topkWeights dataType is invalid, dataType should be float, but is "),
return false);
auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
OPS_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()), OPS_LOG_E(nodeName,
"x dataType is invalid, dataType should be equal to recvX dataType , but is "),
return false);
return true;
}
static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName)
{
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return false);
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(recvXDesc->GetStorageFormat())) ==
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "recvXFormat is invalid"), return false);
auto tokenSrcInfoDesc = context->GetInputDesc(TOKEN_SRC_INFO_INDEX);
OPS_CHECK(tokenSrcInfoDesc == nullptr, OPS_LOG_E(nodeName, "tokenSrcInfoDesc is null."), return false);
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(tokenSrcInfoDesc->GetStorageFormat())) ==
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tokenSrcInfoFormat is invalid"), return false);
auto tpRecvCountsDesc = context->GetOptionalInputDesc(TP_RECV_COUNTS_INDEX);
OPS_CHECK(tpRecvCountsDesc == nullptr, OPS_LOG_E(nodeName, "tpRecvCountsDesc is null."), return false);
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(tpRecvCountsDesc->GetStorageFormat())) ==
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "tpRecvCountsFormat is invalid"), return false);
auto topkWeightsDesc = context->GetInputDesc(TOPK_WEIGHTS_INDEX);
OPS_CHECK(topkWeightsDesc == nullptr, OPS_LOG_E(nodeName, "topkWeightsDesc is null."), return false);
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(topkWeightsDesc->GetStorageFormat())) ==
ge::FORMAT_FRACTAL_NZ, OPS_LOG_E(nodeName, "topkWeightsFormat is invalid"), return false);
auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
OPS_LOG_E(nodeName, "xFormat is invalid"), return false);
return true;
}
static bool CheckTensorShape(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
const char *nodeName, uint32_t localExpertNum)
{
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0);
int64_t topkWeightsDim1 = topkWeightsStorageShape->GetStorageShape().GetDim(1);
int64_t moeExpertNum = static_cast<int64_t>(tilingData.moeCombineNormalInfo.moeExpertNum);
OPS_CHECK((topkWeightsDim1 <= 0) || (topkWeightsDim1 > K_MAX || (topkWeightsDim1 > moeExpertNum)),
OPS_LOG_E(nodeName, "topkWeights's dim1(K) should be in (0, min(%ld, moeExpertNum %ld)], "
"but got topkWeights's dim1=%ld.", K_MAX, moeExpertNum, topkWeightsDim1), return false);
tilingData.moeCombineNormalInfo.k = static_cast<uint32_t>(topkWeightsDim1);
// Validate recvX dimensions and set h
int64_t tpWorldSize = static_cast<int64_t>(tilingData.moeCombineNormalInfo.tpWorldSize);
const gert::StorageShape *recvXStorageShape = context->GetInputShape(RECV_X_INDEX);
int64_t recvXDim1 = recvXStorageShape->GetStorageShape().GetDim(1);
OPS_CHECK((recvXDim1 < H_MIN) || (recvXDim1 > H_MAX),
OPS_LOG_E(nodeName, "recvX's dim1(H) should be in [%ld, %ld], but got %ld.",
H_MIN, H_MAX, recvXDim1), return false); // 32-byte aligned
tilingData.moeCombineNormalInfo.h = static_cast<uint32_t>(recvXDim1);
// Validate epRecvCount and tpRecvCount dimensions
int64_t epWorldSize = static_cast<int64_t>(tilingData.moeCombineNormalInfo.epWorldSize);
int64_t moeExpertPerRankNum = static_cast<int64_t>(tilingData.moeCombineNormalInfo.moeExpertPerRankNum);
// Validate x dimensions
const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
OPS_CHECK(xDim0 != topkWeightsDim0, OPS_LOG_E(nodeName,
"x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false);
OPS_CHECK(xDim1 != recvXDim1, OPS_LOG_E(nodeName,
"x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), return false);
return true;
}
static bool CheckAttrs(gert::TilingContext *context, MoeCombineNormalTilingData &tilingData,
const char *nodeName, uint32_t &localMoeExpertNum)
{
uint32_t epWorldSize = tilingData.moeCombineNormalInfo.epWorldSize;
uint32_t tpWorldSize = tilingData.moeCombineNormalInfo.tpWorldSize;
uint32_t moeExpertNum = tilingData.moeCombineNormalInfo.moeExpertNum;
// Validate if moe expert number can be evenly distributed across multiple machines
OPS_CHECK(moeExpertNum % epWorldSize != 0,
OPS_LOG_E(nodeName, "moeExpertNum should be divisible by epWorldSize, "
"but got moeExpertNum=%d, epWorldSize=%d.", moeExpertNum, epWorldSize), return false);
localMoeExpertNum = moeExpertNum / epWorldSize;
OPS_CHECK(localMoeExpertNum <= 0,
OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum), return false);
// Validate if expert number per card equals 1 when tp=2
OPS_CHECK((localMoeExpertNum > 1) && (tpWorldSize > 1),
OPS_LOG_E(nodeName, "Cannot support multi-moeExpert %d in a rank when tpWorldSize = %d > 1",
localMoeExpertNum, tpWorldSize), return false);
tilingData.moeCombineNormalInfo.moeExpertPerRankNum = localMoeExpertNum;
// Validate topkWeights dimension 0 and set bs
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0);
OPS_CHECK((topkWeightsDim0 <= 0) || (topkWeightsDim0 > BS_UPPER_BOUND),
OPS_LOG_E(nodeName, "Invalid topkWeights dims0(BS) %ld. Should be between [1, %ld].",
topkWeightsDim0, BS_UPPER_BOUND), return false);
tilingData.moeCombineNormalInfo.bs = static_cast<uint32_t>(topkWeightsDim0);
// Validate globalBS
auto attrs = context->GetAttrs();
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is null."), return false);
auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBs is null."), return false);
OPS_LOG_D(nodeName, "MoeCombineNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n",
*globalBsPtr, topkWeightsDim0, epWorldSize);
OPS_CHECK((*globalBsPtr != 0) && ((*globalBsPtr < static_cast<int64_t>(epWorldSize) * topkWeightsDim0) ||
((*globalBsPtr) % (static_cast<int64_t>(epWorldSize)) != 0)), OPS_LOG_E(nodeName, "globalBS is invalid, only "
"support 0 or maxBs(maxBs is the largest bs on all ranks) * epWorldSize, but got globalBS=%ld, "
"bs=%ld, epWorldSize=%u.", *globalBsPtr, topkWeightsDim0, epWorldSize), return false);
tilingData.moeCombineNormalInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
if (*globalBsPtr == 0) {
tilingData.moeCombineNormalInfo.globalBs = static_cast<uint32_t>(topkWeightsDim0) * epWorldSize;
}
return true;
}
static ge::graphStatus TilingCheckMoeCombineNormal(gert::TilingContext *context, const char *nodeName)
{
// Check parameter shape information
OPS_CHECK(!CheckTensorDim(context, nodeName),
OPS_LOG_E(nodeName, "param shape is invalid"), return ge::GRAPH_FAILED);
// Check parameter dataType information
OPS_CHECK(!CheckTensorDataType(context, nodeName),
OPS_LOG_E(nodeName, "param dataType is invalid"), return ge::GRAPH_FAILED);
// Check parameter format information
OPS_CHECK(!CheckTensorFormat(context, nodeName),
OPS_LOG_E(nodeName, "param Format is invalid"), return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus SetWorkspace(gert::TilingContext *context, const char *nodeName)
{
size_t *workspace = context->GetWorkspaceSizes(1);
OPS_CHECK(workspace == nullptr, OPS_LOG_E(nodeName, "get workspace failed"),
return ge::GRAPH_FAILED);
workspace[0] = SYSTEM_NEED_WORKSPACE;
OPS_LOG_D(nodeName, "workspce[0] size is %ld", workspace[0]);
return ge::GRAPH_SUCCESS;
}
static void SetHCommCfg(gert::TilingContext *context, MoeCombineNormalTilingData *tiling,
const std::string groupEp, const std::string groupTp)
{
const char* nodeName = context->GetNodeName();
OPS_LOG_D(nodeName, "MoeCombineNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str());
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
uint32_t opType2 = OP_TYPE_REDUCE_SCATTER;
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
std::string algConfigReduceScatterStr = "ReduceScatter=level0:ring";
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr);
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
mc2CcTilingConfig.SetGroupName(groupTp);
mc2CcTilingConfig.SetOpType(opType2);
mc2CcTilingConfig.SetAlgConfig(algConfigReduceScatterStr);
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2);
}
static ge::graphStatus MoeCombineNormalA3TilingFuncImpl(gert::TilingContext* context)
{
const char *nodeName = context->GetNodeName();
OPS_LOG_D(nodeName, "Enter MoeCombineNormal Tiling func");
MoeCombineNormalTilingData *tilingData = context->GetTilingData<MoeCombineNormalTilingData>();
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
std::string groupEp = "";
std::string groupTp = "";
uint32_t localMoeExpertNum = 1;
// Get input parameter attributes
OPS_CHECK(GetAttrAndSetTilingData(context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED,
OPS_LOG_E(nodeName, "Getting attr failed."), return ge::GRAPH_FAILED);
// Check input/output dim, format, dataType
OPS_CHECK(TilingCheckMoeCombineNormal(context, nodeName) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling check params failed"), return ge::GRAPH_FAILED);
// Check if attribute values are valid
OPS_CHECK(!CheckAttrs(context, *tilingData, nodeName, localMoeExpertNum),
OPS_LOG_E(nodeName, "attr check failed."), return ge::GRAPH_FAILED);
uint32_t epRankId = tilingData->moeCombineNormalInfo.epRankId;
// Check shape dimensions and assign h, k
OPS_CHECK(!CheckTensorShape(context, *tilingData, nodeName, localMoeExpertNum),
OPS_LOG_E(nodeName, "param dim check failed."), return ge::GRAPH_FAILED);
// Validate win area size
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
uint64_t h = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.h);
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.epWorldSize);
uint64_t k = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.k);
uint64_t maxBs = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.globalBs)/ epWorldSize;
// Combine data area: token start address aligned to 512
uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
// Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b)
uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_RECV_IDX_BUFFER;
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET) *
DOUBLE_DATA_BUFFER;
OPS_CHECK((actualSize > maxWindowSize),
OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
" tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE("
"((maxBs * tokenNeedSizeDispatch) + (maxBs * tokenNeedSizeCombine * k) + 3MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.",
maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k,
actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE),
return ge::GRAPH_FAILED);
tilingData->moeCombineNormalInfo.totalWinSize = maxWindowSize;
OPS_CHECK(SetWorkspace(context, nodeName) != ge::GRAPH_SUCCESS,
OPS_LOG_E(context->GetNodeName(), "Tiling set workspace Failed"),
return ge::GRAPH_FAILED);
SetHCommCfg(context, tilingData, groupEp, groupTp);
uint64_t tpWorldSize = static_cast<uint64_t>(tilingData->moeCombineNormalInfo.tpWorldSize);
uint32_t blockDim = 1U;
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint64_t aivNum = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSize = 0UL;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
context->SetBlockDim(blockDim);
tilingData->moeCombineNormalInfo.aivNum = aivNum;
tilingData->moeCombineNormalInfo.totalUbSize = ubSize;
context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously
OPS_LOG_D(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize);
PrintTilingDataInfo(nodeName, *tilingData);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus MoeCombineNormalTilingFunc(gert::TilingContext* context)
{
// recvX data type int32 is not supported
auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
const char *nodeName = context->GetNodeName();
OPS_CHECK(recvXDesc == nullptr, OPS_LOG_E(nodeName, "recvXDesc is null."), return ge::GRAPH_FAILED);
// Check if recvX data type is DT_INT32
OPS_CHECK((recvXDesc->GetDataType() == ge::DT_INT32),
OPS_LOG_E(nodeName, "recvX dataType is invalid, dataType should be bf16 or float16, but is "),
return ge::GRAPH_FAILED);
ge::graphStatus ret = MoeCombineNormalA3TilingFuncImpl(context);
return ret;
}
struct MoeCombineNormalCompileInfo {};
ge::graphStatus TilingParseForMoeCombineNormal(gert::TilingParseContext *context)
{
(void)context;
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(MoeCombineNormal)
.Tiling(MoeCombineNormalTilingFunc)
.TilingParse<MoeCombineNormalCompileInfo>(TilingParseForMoeCombineNormal);
} // namespace optiling

View File

@@ -1,22 +0,0 @@
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "moe_combine_normal.h"
#include "moe_combine_normal_tiling.h"
using namespace AscendC;
using namespace MoeCombineNormalImpl;
extern "C" __global__ __aicore__ void moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount,
GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut,
GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(MoeCombineNormalTilingData);
TPipe pipe;
#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
GET_TILING_DATA_WITH_STRUCT(MoeCombineNormalTilingData, tilingData, tilingGM);
MoeCombineNormal<DTYPE_RECV_X, DTYPE_X, int32_t> op;
op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, workspaceGM, &pipe, &tilingData);
op.Process();
#endif
}

View File

@@ -1,377 +0,0 @@
#ifndef MOE_COMBINE_NORMAL_H
#define MOE_COMBINE_NORMAL_H
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "../common/moe_distribute_base.h"
#include "moe_combine_normal_tiling.h"
namespace MoeCombineNormalImpl {
constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U;
constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U;
constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U;
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL;
constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL;
constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U;
constexpr uint32_t UB_32_ALIGN = 32U;
constexpr uint32_t MUL_256_ALIGN = 256U;
constexpr uint64_t WIN_512_ALIGN = 512UL;
constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U;
constexpr uint8_t DOUBLE_BUFFER = 2;
template<AscendC::HardEvent event>
__aicore__ inline void SyncFunc()
{
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
AscendC::SetFlag<event>(eventID);
AscendC::WaitFlag<event>(eventID);
}
#define TemplateMC2TypeClass typename RecvXType, typename XType, typename SrcInfoType
#define TemplateMC2TypeFunc RecvXType, XType, SrcInfoType
using namespace AscendC;
template <TemplateMC2TypeClass>
class MoeCombineNormal {
public:
__aicore__ inline MoeCombineNormal() {};
__aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights,
GM_ADDR tpRecvCount,GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
const MoeCombineNormalTilingData *tilingData);
__aicore__ inline void Process();
private:
__aicore__ inline void InitMagic();
__aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount,
GM_ADDR topkWeights, GM_ADDR XOut);
__aicore__ inline void InitTilingData(const MoeCombineNormalTilingData *tilingData);
__aicore__ inline void InitBuffLen();
__aicore__ inline void CopyBufferToShareAndSetStatus();
__aicore__ inline void CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex);
__aicore__ inline void ReadBufferFromRemote();
__aicore__ inline void WaitBuffCopy(uint32_t tokenIndex);
__aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId);
__aicore__ inline void ReadBufferAndWeightedSum(uint32_t tokenIndex, uint32_t startTokenIndex);
__aicore__ GM_ADDR GetStateAddrByRankId(const int32_t rankId)
{
GM_ADDR bufferAddr;
if (epRankId_ == rankId) {
bufferAddr = (GM_ADDR)epWinContext_->localWindowsIn;
} else {
bufferAddr = (GM_ADDR)((HcclRankRelationResV2 *)epWinContext_->remoteRes[rankId].nextDevicePtr)->windowsIn;
}
return (GM_ADDR)(bufferAddr + winDataSizeOffset_);
}
__aicore__ GM_ADDR GetBufferAddrByRankId(const int32_t rankId)
{
return GetStateAddrByRankId(rankId) + COMBINE_STATE_WIN_OFFSET;
}
__aicore__ inline void SplitCoreCal(uint32_t totalNum, uint32_t &perCoreNum, uint32_t &startIdx, uint32_t &endIdx)
{
perCoreNum = totalNum / aivNum_;
uint32_t remainderRankNum = totalNum % aivNum_;
startIdx = perCoreNum * coreIdx_;
if (coreIdx_ < remainderRankNum) {
perCoreNum++;
startIdx += coreIdx_;
} else {
startIdx += remainderRankNum;
}
endIdx = startIdx + perCoreNum;
}
__gm__ HcclOpResParam *epWinContext_{nullptr};
__gm__ HcclOpResParam *tpWinContext_{nullptr};
uint32_t axisBS_{0};
uint32_t axisH_{0};
uint32_t axisK_{0};
uint32_t aivNum_{0};
uint32_t epWorldSize_{0};
uint32_t epRankId_{0};
uint32_t coreIdx_{0};
uint32_t moeExpertNum_{0};
uint32_t moeExpertPerRankNum_{0};
uint32_t magic_{0};
uint64_t winDataSizeOffset_{0};
uint32_t selfSendCnt_{0};
uint32_t hRecvXTypeLen_{0};
uint32_t h32AlignFloatLen_{0};
uint32_t h256AlignFloatLen_{0};
uint32_t h32AlignRecvXLen_{0};
uint32_t h512AlignRecvXLen_{0};
TPipe *tpipe_{nullptr};
TQue<QuePosition::VECIN, 1> weightedSumQueue_;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> localCopyQueue_;
TBuf<> stateBuf_;
TBuf<> topkWeightsBuf_;
TBuf<> tokenFloatBuf_;
TBuf<> sumFloatBuf_;
TBuf<> weightedMulBuf_;
TBuf<> srcInfoBuf_;
TBuf<> xOutBuf_;
TBuf<> tempStateBuf_;
GlobalTensor<RecvXType> recvXGM_;
GlobalTensor<SrcInfoType> tokenSrcInfoGM_;
GlobalTensor<SrcInfoType> epRecvCountGM_;
GlobalTensor<float> topkWeightsGM_;
GlobalTensor<XType> xOutGlobal_;
GM_ADDR localRankGM_;
GM_ADDR workspaceGM_;
};
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitMagic()
{
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
epWinContext_ = (__gm__ HcclOpResParam*)contextGM0;
GlobalTensor<int32_t> selfMagicTensor;
selfMagicTensor.SetGlobalBuffer((__gm__ int32_t*)((GM_ADDR)epWinContext_->localWindowsExp + MAGIC_WIN_OFFSET +
coreIdx_ * WIN_512_ALIGN));
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfMagicTensor);
magic_ = selfMagicTensor(0);
selfMagicTensor(0) = ((magic_ == 0) ? 1 : 0);
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfMagicTensor);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(
GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR XOut)
{
recvXGM_.SetGlobalBuffer((__gm__ RecvXType*)recvX);
tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType*)tokenSrcInfo);
epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t*)epRecvCount);
topkWeightsGM_.SetGlobalBuffer((__gm__ float*)topkWeights);
xOutGlobal_.SetGlobalBuffer((__gm__ XType*)XOut);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitTilingData(const MoeCombineNormalTilingData *tilingData)
{
axisBS_ = tilingData->moeCombineNormalInfo.bs;
axisH_ = tilingData->moeCombineNormalInfo.h;
axisK_ = tilingData->moeCombineNormalInfo.k;
aivNum_ = tilingData->moeCombineNormalInfo.aivNum;
moeExpertNum_ = tilingData->moeCombineNormalInfo.moeExpertNum;
moeExpertPerRankNum_ = tilingData->moeCombineNormalInfo.moeExpertPerRankNum;
epWorldSize_ = tilingData->moeCombineNormalInfo.epWorldSize;
epRankId_ = tilingData->moeCombineNormalInfo.epRankId;
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::InitBuffLen()
{
uint32_t hFloatSize = axisH_ * static_cast<uint32_t>(sizeof(float));
h32AlignFloatLen_ = Ceil(hFloatSize, UB_32_ALIGN) * UB_32_ALIGN;
h256AlignFloatLen_ = Ceil(hFloatSize, MUL_256_ALIGN) * MUL_256_ALIGN;
hRecvXTypeLen_ = axisH_ * sizeof(RecvXType);
h32AlignRecvXLen_ = Ceil(hRecvXTypeLen_, UB_32_ALIGN) * UB_32_ALIGN;
h512AlignRecvXLen_ = Ceil(hRecvXTypeLen_, WIN_512_ALIGN) * WIN_512_ALIGN;
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo,
GM_ADDR epRecvCount, GM_ADDR topkWeights,
GM_ADDR tpRecvCount, GM_ADDR XOut,
GM_ADDR workspaceGM, TPipe *pipe,
const MoeCombineNormalTilingData *tilingData)
{
workspaceGM_ = workspaceGM;
tpipe_ = pipe;
coreIdx_ = GetBlockIdx();
InitMagic();
InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, XOut);
InitTilingData(tilingData);
InitBuffLen();
PipeBarrier<PIPE_ALL>();
winDataSizeOffset_ = static_cast<uint64_t>(magic_) * (tilingData->moeCombineNormalInfo.totalWinSize / 2UL);
localRankGM_ = GetBufferAddrByRankId(epRankId_);
DataCacheCleanAndInvalid<SrcInfoType, CacheLine::SINGLE_CACHE_LINE,
DcciDst::CACHELINE_OUT>(epRecvCountGM_[moeExpertNum_ - 1]);
selfSendCnt_ = epRecvCountGM_(moeExpertNum_ - 1);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShareAndSetStatus()
{
PipeBarrier<PIPE_ALL>();
uint32_t perBlockSendNum = 0, startTokenId = 0, endTokenId = 0;
SplitCoreCal(selfSendCnt_, perBlockSendNum, startTokenId, endTokenId);
if (perBlockSendNum == 0U) {
return;
}
uint32_t blockLen = static_cast<uint32_t>(perBlockSendNum * TOKEN_SRC_INFO_LEN * sizeof(uint32_t));
tpipe_->Reset();
tpipe_->InitBuffer(stateBuf_, UB_32_ALIGN);
tpipe_->InitBuffer(localCopyQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_);
tpipe_->InitBuffer(srcInfoBuf_, blockLen);
LocalTensor<uint32_t> statusTensor = stateBuf_.AllocTensor<uint32_t>();
Duplicate<uint32_t>(statusTensor, 0x3F800000, FLOAT_NUM_PER_ALIGN);
LocalTensor<SrcInfoType> srcInfoLocal = srcInfoBuf_.Get<SrcInfoType>();
const DataCopyExtParams dataCopyParams{1U, blockLen, 0U, 0U, 0U};
const DataCopyPadExtParams<SrcInfoType> padParams{false, 0U, 0U, 0U};
DataCopyPad(srcInfoLocal, tokenSrcInfoGM_[startTokenId * TOKEN_SRC_INFO_LEN], dataCopyParams, padParams);
SyncFunc<AscendC::HardEvent::MTE2_S>();
for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) {
uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN;
uint32_t srcRankId = static_cast<uint32_t>(srcInfoLocal(index + RANK_ID_OFFSET_IN_SRC_INFO));
uint32_t srcTokenId = static_cast<uint32_t>(srcInfoLocal(index + TOKEN_IDX_OFFSET_IN_SRC_INFO));
uint32_t srcTopkId = static_cast<uint32_t>(srcInfoLocal(index + TOPK_IDX_OFFSET_IN_SRC_INFO));
CopyBufferToShare(srcRankId, srcTokenId, srcTopkId, tokenIndex);
PipeBarrier<PIPE_ALL>();
SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId);
}
SyncFunc<AscendC::HardEvent::MTE3_S>();
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId,
uint32_t srcTopkId, uint32_t tkIndex)
{
uint32_t tokenOffset = tkIndex * axisH_;
GM_ADDR dstGM = GetBufferAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_;
GlobalTensor<XType> dstWindow;
dstWindow.SetGlobalBuffer((__gm__ XType*)dstGM);
DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
LocalTensor<RecvXType> localCopyTensor;
localCopyTensor = localCopyQueue_.AllocTensor<RecvXType>();
DataCopyPad(localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams);
localCopyQueue_.EnQue(localCopyTensor);
localCopyTensor = localCopyQueue_.DeQue<RecvXType>();
DataCopyPad(dstWindow, localCopyTensor, xOutCopyParams);
localCopyQueue_.FreeTensor<RecvXType>(localCopyTensor);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId,
uint32_t srcTopkId)
{
LocalTensor<uint32_t> statusTensor = stateBuf_.AllocTensor<uint32_t>();
GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN;
GlobalTensor<uint32_t> stateGMTensor;
stateGMTensor.SetGlobalBuffer((__gm__ uint32_t*)stateGM);
DataCopy<uint32_t>(stateGMTensor, statusTensor, FLOAT_NUM_PER_ALIGN);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::WaitBuffCopy(uint32_t tokenIndex)
{
uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN;
GM_ADDR stateGM = GetStateAddrByRankId(epRankId_) + tokenIndex * axisK_ * UB_32_ALIGN; // Calculate address offset
GlobalTensor<float> stateGMTensor;
stateGMTensor.SetGlobalBuffer((__gm__ float*)stateGM);
float current = (float)0.0;
float target = (float)1.0 * axisK_ * FLOAT_NUM_PER_ALIGN;
SumParams sumPerKParams{1, calCount, calCount};
LocalTensor<float> stateTensorLocal = stateBuf_.Get<float>();
LocalTensor<float> tempStateTensorLocal = tempStateBuf_.Get<float>();
while (current != target) {
SyncFunc<AscendC::HardEvent::S_MTE2>();
DataCopy<float>(stateTensorLocal, stateGMTensor, calCount);
SyncFunc<AscendC::HardEvent::MTE2_V>();
Sum(tempStateTensorLocal, stateTensorLocal, sumPerKParams);
SyncFunc<AscendC::HardEvent::V_S>();
current = tempStateTensorLocal(0);
}
SyncFunc<AscendC::HardEvent::S_V>();
Duplicate<float>(tempStateTensorLocal, (float)0.0, calCount);
SyncFunc<AscendC::HardEvent::V_MTE3>();
DataCopy<float>(stateGMTensor, tempStateTensorLocal, calCount);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferAndWeightedSum(uint32_t tokenIndex,
uint32_t startTokenIndex)
{
LocalTensor<float> tokenFloatLocal = tokenFloatBuf_.Get<float>();
LocalTensor<float> weightedMulBufLocal = weightedMulBuf_.Get<float>();
LocalTensor<float> sumFloatBufLocal = sumFloatBuf_.Get<float>();
LocalTensor<float> topkWeightsLocal = topkWeightsBuf_.Get<float>();
LocalTensor<uint32_t> stateTensorLocal = stateBuf_.Get<uint32_t>();
Duplicate(sumFloatBufLocal, static_cast<float>(0), axisH_);
const DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
for (uint32_t topkId = 0U; topkId < axisK_; topkId++) {
float scale = topkWeightsLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId);
GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_;
GlobalTensor<XType> localTokenTensor;
localTokenTensor.SetGlobalBuffer((__gm__ XType*)localTokenAddr);
LocalTensor<XType> tmpToken = weightedSumQueue_.AllocTensor<XType>();
const DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
DataCopyPad(tmpToken, localTokenTensor, xOutCopyParams, copyPadExtParams);
weightedSumQueue_.EnQue(tmpToken);
tmpToken = weightedSumQueue_.DeQue<XType>();
Cast(tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_);
PipeBarrier<PIPE_V>();
AscendC::Muls(weightedMulBufLocal, tokenFloatLocal, scale, axisH_);
PipeBarrier<PIPE_V>();
AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, weightedMulBufLocal, axisH_);
weightedSumQueue_.FreeTensor<XType>(tmpToken);
}
PipeBarrier<PIPE_V>();
LocalTensor<XType> xOutLocal = xOutBuf_.Get<XType>();
Cast(xOutLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, axisH_);
SyncFunc<AscendC::HardEvent::V_MTE3>();
DataCopyPad(xOutGlobal_[tokenIndex * axisH_], xOutLocal, xOutCopyParams);
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferFromRemote()
{
if (axisBS_ == 0U) {
return;
}
uint32_t tokenPerBlock = 0U, startTokenIndex = 0U, endTokenIndex = 0U;
SplitCoreCal(axisBS_, tokenPerBlock, startTokenIndex, endTokenIndex);
if (tokenPerBlock == 0U) {
return;
}
tpipe_->Reset();
tpipe_->InitBuffer(xOutBuf_, h32AlignRecvXLen_);
tpipe_->InitBuffer(tokenFloatBuf_, h32AlignFloatLen_);
tpipe_->InitBuffer(weightedMulBuf_, h256AlignFloatLen_);
tpipe_->InitBuffer(sumFloatBuf_, h32AlignFloatLen_);
tpipe_->InitBuffer(weightedSumQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_);
tpipe_->InitBuffer(stateBuf_, (axisK_) * UB_32_ALIGN);
tpipe_->InitBuffer(tempStateBuf_, (axisK_) * UB_32_ALIGN);
tpipe_->InitBuffer(topkWeightsBuf_, tokenPerBlock * axisK_ * sizeof(float));
LocalTensor<float> topkWeightsLocal = topkWeightsBuf_.Get<float>();
const DataCopyExtParams bskParams{1U, static_cast<uint32_t>(tokenPerBlock * axisK_ * sizeof(float)), 0U, 0U, 0U};
const DataCopyPadExtParams<float> copyPadFloatParams{false, 0U, 0U, 0U};
DataCopyPad(topkWeightsLocal, topkWeightsGM_[startTokenIndex * axisK_], bskParams, copyPadFloatParams);
SyncFunc<AscendC::HardEvent::MTE2_S>();
for (uint32_t tokenIndex = startTokenIndex; tokenIndex < endTokenIndex; tokenIndex++) {
WaitBuffCopy(tokenIndex);
SyncFunc<AscendC::HardEvent::MTE3_V>(); // Sync with result datacopy on same tensor
ReadBufferAndWeightedSum(tokenIndex, startTokenIndex);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void MoeCombineNormal<TemplateMC2TypeFunc>::Process()
{
if ASCEND_IS_AIV { // All AIV processing
CopyBufferToShareAndSetStatus();
ReadBufferFromRemote();
}
}
} // MoeCombineNormalImpl
#endif // MOE_COMBINE_IMPL_H

View File

@@ -1,33 +0,0 @@
#ifndef MOE_COMBINE_NORMAL_TILING_H
#define MOE_COMBINE_NORMAL_TILING_H
#include <cstdint>
#include "kernel_tiling/kernel_tiling.h"
// a3
struct MoeCombineNormalInfo {
uint32_t epWorldSize;
uint32_t tpWorldSize;
uint32_t epRankId;
uint32_t tpRankId;
uint32_t expertShardType;
uint32_t moeExpertNum;
uint32_t moeExpertPerRankNum;
uint32_t globalBs;
uint32_t bs;
uint32_t k;
uint32_t h;
uint32_t aivNum;
uint64_t totalUbSize;
uint64_t totalWinSize;
float armAvgFactor;
float epsilon;
};
struct MoeCombineNormalTilingData {
Mc2InitTiling mc2InitTiling;
Mc2CcTiling mc2CcTiling1;
Mc2CcTiling mc2CcTiling2;
MoeCombineNormalInfo moeCombineNormalInfo;
};
#endif //MOE_COMBINE_NORMAL_TILING_H