[Kernel] Add moe_gating_top_k operator support for Ascend NPU (#5579)

### What this PR does / why we need it?

1.replace moe_gating_top_k from torch_npu with custom op
2.enable the  renorm function of moe_gating_top_k in softmax scenerio

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
No need test

- vLLM version: v0.13.0
- vLLM main:
7157596103

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
This commit is contained in:
ZCG12345
2026-01-07 21:42:31 +08:00
committed by GitHub
parent 1165b2c863
commit 3be8e33fe9
32 changed files with 4667 additions and 13 deletions

View File

@@ -0,0 +1,42 @@
# ----------------------------------------------------------------------------
# This program is free software, you can redistribute it and/or modify.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------
add_ops_compile_options(
OP_NAME MoeGatingTopK
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
# Host aclnn
if (BUILD_OPEN_PROJECT)
target_sources(op_host_aclnn PRIVATE
moe_gating_top_k_def.cpp
)
# Tiling
target_sources(optiling PRIVATE
moe_gating_top_k_tiling.cpp
moe_gating_top_k_tiling_base.cpp
moe_gating_top_k_tiling_arch35.cpp
)
target_sources(opsproto PRIVATE
moe_gating_top_k_proto.cpp
moe_gating_top_k_infershape.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
endif()

View File

@@ -0,0 +1,56 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file math_util.h
* \brief
*/
#ifndef TILING_MATMUL_MATH_UTIL_H
#define TILING_MATMUL_MATH_UTIL_H
#include <array>
#include <cstdint>
#include <vector>
#include <utility>
namespace matmul_tiling {
class MathUtil {
public:
static bool IsEqual(float leftValue, float rightValue);
template<typename T>
static auto CeilDivision(T num1, T num2) -> T
{
if (num2 == 0) {
return 0;
}
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
static_cast<int64_t>(num2));
}
template<typename T>
static auto Align(T num1, T num2) -> T
{
return CeilDivision(num1, num2) * num2;
}
static int32_t AlignDown(int32_t num1, int32_t num2);
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static bool CheckFactorNumSatisfy(const int32_t dim);
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
bool isKDim);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
const int32_t coreNum, const int32_t maxNum);
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
};
} // namespace matmul_tiling
#endif // _MATH_UTIL_H_

View File

@@ -0,0 +1,71 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
class MoeGatingTopK : public OpDef {
public:
explicit MoeGatingTopK(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("bias")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Output("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("expert_idx")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("k").Int();
this->Attr("k_group").AttrType(OPTIONAL).Int(1);
this->Attr("group_count").AttrType(OPTIONAL).Int(1);
this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0);
this->Attr("renorm").AttrType(OPTIONAL).Int(0);
this->Attr("norm_type").AttrType(OPTIONAL).Int(0);
this->Attr("out_flag").AttrType(OPTIONAL).Bool(false);
this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0);
this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
OpAICoreConfig regbaseCfg;
regbaseCfg.DynamicCompileStaticFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt");
this->AICore().AddConfig("ascend910_95", regbaseCfg);
}
};
OP_ADD(MoeGatingTopK);
} // namespace ops

View File

@@ -0,0 +1,147 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_infershape.cpp
* \brief
*/
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "error_log.h"
#include <string>
#include <string>
#define TO_STRING(x) std::string(#x)
using namespace ge;
namespace ops {
static constexpr size_t DIM_ONE = 1;
static constexpr size_t DIM_TWO = 2;
static constexpr int64_t NEG_ONE = -1;
static constexpr int64_t X_INDEX = 0;
static constexpr int64_t BIAS_INDEX = 1;
static constexpr int64_t Y_INDEX = 0;
static constexpr int64_t EXPERT_IDX_INDEX = 1;
static constexpr int64_t OUT_INDEX = 2;
static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape)
{
int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
if (XRows < NEG_ONE || expertNum < NEG_ONE) {
OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape,
const int64_t k)
{
if (xShape->GetDimNum() == 1U) {
if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) {
OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.",
TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
} else if (xShape->GetDimNum() != DIM_TWO) {
OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.",
TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
if (k < 0) {
OP_LOGE(context, "k must be a non-negative number.");
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k)
{
OP_LOGD(context, "x shape is: %s.", TO_STRING(*xShape).c_str());
OP_LOGD(context, "k is: %ld.", k);
}
static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape,
const gert::Shape *expertIdxShape, const gert::Shape *outShape)
{
OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str());
OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str());
OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str());
}
static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context)
{
OP_LOGD(context, "Begin to do MoeGatingTopKInfershape.");
// 获取输入shape
const gert::Shape *xShape = context->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
gert::Shape *yShape = context->GetOutputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
gert::Shape *expertIdxShape = context->GetOutputShape(1);
OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape);
gert::Shape *outShape = context->GetOutputShape(2);
OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
// 获取attr
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(0);
OP_CHECK_NULL_WITH_CONTEXT(context, kPtr);
const int64_t k = *kPtr;
ShowInputShapeInfo(context, xShape, k);
// 参数校验
if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
yShape->SetDimNum(DIM_TWO);
yShape->SetDim(0U, rows);
yShape->SetDim(1U, k);
expertIdxShape->SetDimNum(DIM_TWO);
expertIdxShape->SetDim(0U, rows);
expertIdxShape->SetDim(1U, k);
outShape->SetDimNum(DIM_TWO);
outShape->SetDim(0U, rows);
outShape->SetDim(1U, expertNum);
ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape);
OP_LOGD(context, "End to do MoeGatingTopKInfershape.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context)
{
OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType.");
auto xDtype = context->GetInputDataType(0);
context->SetOutputDataType(Y_INDEX, xDtype);
context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32);
context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT);
OP_LOGD(context, "End to do MoeGatingTopKInferDataType.");
return ge::GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK);
} // namespace ops

View File

@@ -0,0 +1,15 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.h
* \brief
*/
#include "moe_gating_top_k_proto.h"

View File

@@ -0,0 +1,66 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.h
* \brief
*/
#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#include "graph/operator_reg.h"
namespace ge {
/**
* @brief Compute renorm(sigmoid) and topk for moe input.
*
* @par Inputs:
* @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count.
* @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x.
*
* @par Outputs:
* @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x.
The size of the non-1 axis must be the same as that of the corresponding axis of x.
The size of the -1 axis must be the same as that of k.
* @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y.
* @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x.
*
* @par Attributes:
* @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value.
* @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1.
* @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1.
* @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0.
* @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0.
* @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0.
* @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false.
* @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0.
* @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20.
*/
REG_OP(MoeGatingTopK)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(expert_idx, TensorType({DT_INT32}))
.OUTPUT(out, TensorType({DT_FLOAT}))
.REQUIRED_ATTR(k, Int)
.ATTR(k_group, Int, 1)
.ATTR(group_count, Int, 1)
.ATTR(group_select_mode, Int, 0)
.ATTR(renorm, Int, 0)
.ATTR(norm_type, Int, 0)
.ATTR(out_flag, Bool, false)
.ATTR(routed_scaling_factor, Float, 1.0)
.ATTR(eps, Float, 1e-20f)
.OP_END_FACTORY_REG(MoeGatingTopK)
} // namespace ge
#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_

View File

@@ -0,0 +1,573 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling.cpp
* \brief
*/
#include <cmath>
#include "register/op_def_registry.h"
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "platform/platform_info.h"
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
const static int64_t MAX_EXPERT_COUNT = 2048;
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216;
const static uint32_t DATATYPESIZE_FLOAT = 4;
const static bool IS_LARGEST = true;
const static bool IS_INITINDEX = false;
const static bool IS_REUSESOURCE = false;
const static uint64_t WITH_GROUP_CONDITION = 1;
const static uint64_t WITHOUT_GROUP_CONDITION = 2;
const static uint64_t MAX_IN_GROUP_CONDITION = 3;
constexpr int32_t ROW_COUNT_PER_TASK = 1;
const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0;
const static uint64_t TILING_KEY_WITHOUT_GROUP = 1;
const static uint64_t TILING_KEY_GENERALIZED = 2;
inline static int64_t CeilLog4(int64_t x)
{
return static_cast<int64_t>(std::ceil(std::log(x) / std::log(4))); // 4 for four
}
class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
{
Reset();
}
~MoeGatingTopKTilingBase() override = default;
void Reset(gert::TilingContext *context) override
{
TilingBaseClass::Reset(context);
Reset();
}
protected:
bool IsCapable() override
{
return true;
}
ge::graphStatus GetPlatformInfo() override;
ge::graphStatus GetShapeAttrsInfo() override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus DoLibApiTiling() override;
uint64_t GetTilingKey() const override;
ge::graphStatus GetWorkspaceSize() override;
ge::graphStatus PostTiling() override;
void Reset();
private:
ge::graphStatus CheckInputShape();
ge::graphStatus CheckAttr();
ge::graphStatus CheckOutShape();
void SplitRows();
void CalTmpBufUbSize();
const gert::Shape *xShape_ = nullptr;
const gert::Shape *biasShape_ = nullptr;
const gert::Shape *yShape_ = nullptr;
const gert::Shape *expertIdxShape_ = nullptr;
const gert::Shape *outShape_ = nullptr;
int64_t rows_ = 0;
int64_t expertCount_ = 0;
int64_t addBias_ = 0;
int64_t k_ = 0;
int64_t kGroup_ = 0;
int64_t groupCount_ = 0;
int64_t perGroupExpertCount_ = 0;
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
int64_t renorm_ = RENORM_NO;
int64_t normType_ = NORM_TYPE_SOFTMAX;
int64_t outFlag_ = OUT_FLAG_FALSE;
float routedScalingFactor_ = 1.0;
float eps_ = 1e-20f;
int64_t inputDtypeSize_;
const char *opName_ = "";
MoeGatingTopKTilingData moeGatingTopKTilingData_;
};
ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::CheckAttr()
{
OP_CHECK_IF(
expertCount_ > MAX_EXPERT_COUNT,
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ > groupCount_,
OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
return ge::GRAPH_FAILED);
OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
return ge::GRAPH_FAILED);
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
OP_LOGE(context_, "Expert count : %ld is not divisible by k_group: %ld", expertCount_, groupCount_),
return ge::GRAPH_FAILED);
perGroupExpertCount_ = expertCount_ / groupCount_;
OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
OP_CHECK_IF(perGroupExpertCount_ < 1,
OP_LOGE(context_, "group expert count is: %ld, but should be greater than 1.", perGroupExpertCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
OP_LOGE(context_,
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
perGroupExpertCount_, groupSelectMode_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_),
return ge::GRAPH_FAILED);
int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L);
OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign);
if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) {
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_);
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_);
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str());
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
if (biasShape_ != nullptr) {
OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str());
}
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str());
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str());
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr);
outShape_ = &outPtr->GetStorageShape();
if (outShape_ != nullptr) {
OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str());
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str());
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str());
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str());
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc);
auto normOutDtype = normOutDesc->GetDataType();
OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.",
ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()),
return ge::GRAPH_FAILED);
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
OP_LOGI(context_, "Attr k is: %ld", k_);
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
OP_LOGI(context_, "Attr k_group is: %ld", kGroup_);
moeGatingTopKTilingData_.set_kGroup(kGroup_);
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
OP_LOGI(context_, "Attr group_count is: %ld", groupCount_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
}
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_);
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
}
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
OP_LOGI(context_, "Attr renorm is: %ld", renorm_);
moeGatingTopKTilingData_.set_renorm(renorm_);
}
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
OP_LOGI(context_, "Attr norm_type is: %ld", normType_);
moeGatingTopKTilingData_.set_normType(normType_);
}
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_);
moeGatingTopKTilingData_.set_outFlag(outFlag_);
}
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_);
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
}
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
OP_LOGI(context_, "Attr eps is: %f", eps_);
moeGatingTopKTilingData_.set_eps(eps_);
}
OP_LOGI(context_, "Attr eps is: %f ", eps_);
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo()
{
auto platformInfo = context_->GetPlatformInfo();
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSizePlatForm;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
aicoreParams_.ubSize = ubSizePlatForm;
OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape()
{
OP_LOGI(context_, "555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
if (outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
outShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
outShape_->GetDim(0), outShape_->GetDim(0)),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(1) != k_),
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
xShape_->GetDim(1)),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
void MoeGatingTopKTilingBase::SplitRows()
{
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
// perCoreRows cannot be 0
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L));
OP_LOGI(context_, "vms count is: %ld", vmsCount);
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
}
void MoeGatingTopKTilingBase::CalTmpBufUbSize()
{
std::vector<int64_t> shape_vec = {expertCount_};
ge::Shape shape(shape_vec);
uint32_t maxValue = 0;
uint32_t minValue = 0;
AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue);
int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast<int64_t>(sizeof(float));
moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast<int64_t>(minValue)));
}
ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling()
{
OP_LOGI(context_, "DoOpTiling: start");
auto ret = CheckInputShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckOutShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckAttr();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
CalTmpBufUbSize();
SplitRows();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize()
{
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
uint64_t MoeGatingTopKTilingBase::GetTilingKey() const
{
if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ &&
groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID &&
!outFlag_) {
return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF;
} else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) {
return TILING_KEY_WITHOUT_GROUP;
} else {
return TILING_KEY_GENERALIZED;
}
}
void MoeGatingTopKTilingBase::Reset()
{
opName_ = nullptr;
return;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000);
} // namespace optiling

View File

@@ -0,0 +1,86 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_tiling.h
* \brief
*/
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#include <cmath>
#include <cstdint>
#include <vector>
#include <algorithm>
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "platform/platform_infos_def.h"
#include "math_util.h"
//#include "util/extern_math_util.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData)
BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData)
struct MoeGatingTopKCompileInfo {};
} // namespace optiling
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H

View File

@@ -0,0 +1,521 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling_arch35.cpp
* \brief
*/
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "platform/platform_info.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000;
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
const static int64_t MAX_EXPERT_COUNT = 2048;
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
const static int64_t MRGSORT_SIZE = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast<int64_t>(16 * 1024 * 1024); // 预留16M空间
class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
{
Reset();
}
~MoeGatingTopKTilingRegbase() override = default;
void Reset(gert::TilingContext *context) override
{
TilingBaseClass::Reset(context);
Reset();
}
protected:
bool IsCapable() override
{
if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) {
return false;
}
return true;
}
ge::graphStatus GetPlatformInfo() override;
ge::graphStatus GetShapeAttrsInfo() override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus DoLibApiTiling() override;
uint64_t GetTilingKey() const override;
ge::graphStatus GetWorkspaceSize() override;
ge::graphStatus PostTiling() override;
void Reset();
private:
ge::graphStatus CheckInputShape();
ge::graphStatus CheckAttr();
ge::graphStatus CheckOutShape();
void CalTmpBufUbSize();
void SplitRows();
void Tiling4GatherOutComputeSplitK();
const gert::Shape *xShape_ = nullptr;
const gert::Shape *biasShape_ = nullptr;
const gert::Shape *yShape_ = nullptr;
const gert::Shape *expertIdxShape_ = nullptr;
const gert::Shape *outShape_ = nullptr;
int64_t rows_;
int64_t expertCount_;
int64_t addBias_ = 0;
int64_t k_;
int64_t kGroup_ = 1;
int64_t groupCount_ = 1;
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
int64_t renorm_ = RENORM_NO;
int64_t normType_ = NORM_TYPE_SOFTMAX;
int64_t outFlag_ = OUT_FLAG_FALSE;
float routedScalingFactor_ = 1.0;
float eps_ = 1e-20f;
int64_t inputDtypeSize_;
const char *opName_ = "";
MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_;
platform_ascendc::SocVersion socVersion;
};
ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
OP_CHECK_IF(
expertCount_ > MAX_EXPERT_COUNT,
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.",
biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr()
{
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ > groupCount_,
OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_, k_),
return ge::GRAPH_FAILED);
if (kGroup_ == groupCount_ || groupCount_ == expertCount_) {
kGroup_ = 1;
groupCount_ = 1;
}
moeGatingTopKTilingData_.set_kGroup(kGroup_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
int64_t groupExpertCount = expertCount_ / groupCount_;
int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L);
moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_);
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ * groupExpertCount < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_ * groupExpertCount, k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupExpertCount < 1,
OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2,
OP_LOGE(context_,
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
groupExpertCount, groupSelectMode_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(renorm_ != RENORM_NO,
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
return ge::GRAPH_FAILED);
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
if (outPtr != nullptr) {
outShape_ = &outPtr->GetStorageShape();
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
}
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
}
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
}
moeGatingTopKTilingData_.set_renorm(renorm_);
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
}
moeGatingTopKTilingData_.set_normType(normType_);
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
}
moeGatingTopKTilingData_.set_outFlag(outFlag_);
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
}
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
}
moeGatingTopKTilingData_.set_eps(eps_);
OP_LOGI(context_, "Attr eps is: %f ", eps_);
auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
if (outFlag_ && outDesc != nullptr) {
auto outDtype = outDesc->GetDataType();
OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.",
ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()),
return ge::GRAPH_FAILED);
}
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo()
{
auto platformInfo = context_->GetPlatformInfo();
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
socVersion = ascendcPlatform.GetSocVersion();
uint64_t ubSizePlatForm;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
aicoreParams_.ubSize = ubSizePlatForm;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape()
{
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
if (outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
outShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
outShape_->GetDim(0), outShape_->GetDim(0)),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(1) != k_),
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
xShape_->GetDim(1)),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() {
std::vector<int64_t> shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()};
ge::Shape softmaxShape(shape_vec);
uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true);
AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData);
}
void MoeGatingTopKTilingRegbase::SplitRows()
{
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
if (perCoreRows == 0) {
OP_LOGE(context_, "perCoreRows can't be 0.");
return;
}
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
int64_t vmsCount = 0;
if (kGroup_ > MRGSORT_SIZE) {
int64_t index = MRGSORT_SIZE;
while (index < kGroup_) {
index = index * MRGSORT_SIZE;
vmsCount++;
}
}
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
}
ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling()
{
auto ret = CheckInputShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckAttr();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckOutShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
CalTmpBufUbSize();
SplitRows();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize()
{
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const
{
return MOE_GATING_TOP_K_REGBASE_TILING_KEY;
}
void MoeGatingTopKTilingRegbase::Reset()
{
opName_ = nullptr;
return;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000);
} // namespace optiling

View File

@@ -0,0 +1,38 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling_base.cpp
* \brief
*/
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "error_log.h"
#include "kernel_tiling/kernel_tiling.h"
namespace optiling {
static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context)
{
return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
}
static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context)
{
(void)context;
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(MoeGatingTopK)
.Tiling(TilingForMoeGatingTopK)
.TilingParse<MoeGatingTopKCompileInfo>(TilingPrepareForMoeGatingTopK);
} // namespace optiling