[Kernel] Add moe_gating_top_k operator support for Ascend NPU (#5579)

### What this PR does / why we need it?

1. Replace moe_gating_top_k from torch_npu with a custom op
2. Enable the renorm function of moe_gating_top_k in the softmax scenario

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
No dedicated test needed; covered by existing CI.

- vLLM version: v0.13.0
- vLLM main:
7157596103

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
This commit is contained in:
ZCG12345
2026-01-07 21:42:31 +08:00
committed by GitHub
parent 1165b2c863
commit 3be8e33fe9
32 changed files with 4667 additions and 13 deletions

View File

@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -70,6 +70,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"dispatch_layout"
"notify_dispatch"
"moe_init_routing_custom"
"moe_gating_top_k"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"

View File

@@ -0,0 +1,42 @@
# ----------------------------------------------------------------------------
# This program is free software, you can redistribute it and/or modify.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------
# Device-kernel compile options for the MoeGatingTopK op.
add_ops_compile_options(
OP_NAME MoeGatingTopK
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
# Host aclnn
if (BUILD_OPEN_PROJECT)
# Host-side operator definition (IR registration in moe_gating_top_k_def.cpp).
target_sources(op_host_aclnn PRIVATE
moe_gating_top_k_def.cpp
)
# Tiling
target_sources(optiling PRIVATE
moe_gating_top_k_tiling.cpp
moe_gating_top_k_tiling_base.cpp
moe_gating_top_k_tiling_arch35.cpp
)
# Prototype + infershape/infer-datatype registration for the graph engine.
target_sources(opsproto PRIVATE
moe_gating_top_k_proto.cpp
moe_gating_top_k_infershape.cpp
)
# Tiling sources include local headers (error_log.h, tiling headers) from this directory.
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
endif()

View File

@@ -0,0 +1,56 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <cstdio>
#include <string>
#include "toolchain/slog.h"

// Minimal logging shims for the op tiling / infershape code in this directory.
//
// Fixes vs. the previous version:
//  1. The printf-style message (__VA_ARGS__) was appended as surplus arguments
//     to a format string that only contained the "[TAG][%s] " prefix, so the
//     actual log message was silently dropped. It is now printed separately.
//  2. Call sites pass either a node-name C string or a context pointer as the
//     first argument; formatting a non-string pointer with %s is undefined
//     behavior, so the tag is reduced through the OpLogTag() overloads below.
//  3. "[ERRORx]" typo in OP_LOGE_WITHOUT_REPORT corrected to "[ERROR]".

namespace optiling {
// Reduce the heterogeneous first macro argument to a printable C string.
inline const char *OpLogTag(const char *name)
{
    return name != nullptr ? name : "nil";
}
template <typename T>
inline const char *OpLogTag(T *)
{
    return "op"; // context pointer: no printable name available at this level
}
} // namespace optiling

// Info/debug logs are compiled out in this build.
#define OP_LOGI(opname, ...)
#define OP_LOGD(opname, ...)

#define OP_LOGW(opname, ...) \
    do { \
        printf("[WARN][%s] ", optiling::OpLogTag(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do { \
        printf("[ERROR][%s] ", optiling::OpLogTag(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

#define OP_LOGE(opname, ...) \
    do { \
        printf("[ERROR][%s] ", optiling::OpLogTag(opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)

// Runs log_func and expr (typically "return ...") when cond holds.
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)

// Null-check helper for infershape/tiling contexts; fails the graph call.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file math_util.h
 * \brief Integer arithmetic and factorization helpers for matmul tiling.
 */
#ifndef TILING_MATMUL_MATH_UTIL_H
#define TILING_MATMUL_MATH_UTIL_H
#include <array>
#include <cstdint>
#include <vector>
#include <utility>
namespace matmul_tiling {
// Stateless helper collection; every member is static. Only the two templates
// are defined here — the remaining members are declared only, so their exact
// semantics must be confirmed against the .cpp definitions.
class MathUtil {
public:
static bool IsEqual(float leftValue, float rightValue);
// Ceiling division: ceil(num1 / num2). Returns 0 when num2 == 0 (callers
// must treat 0 as "invalid divisor"). The sum is computed in int64_t so
// num1 + num2 - 1 cannot overflow a narrower T.
template<typename T>
static auto CeilDivision(T num1, T num2) -> T
{
if (num2 == 0) {
return 0;
}
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
static_cast<int64_t>(num2));
}
// Rounds num1 up to the next multiple of num2 (yields 0 when num2 == 0,
// because CeilDivision returns 0 in that case).
template<typename T>
static auto Align(T num1, T num2) -> T
{
return CeilDivision(num1, num2) * num2;
}
// Presumably rounds num1 down to a multiple of num2 — confirm in definition.
static int32_t AlignDown(int32_t num1, int32_t num2);
// Presumably checks whether a * b fits in int32 and stores the product in c.
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static bool CheckFactorNumSatisfy(const int32_t dim);
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
bool isKDim);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
const int32_t coreNum, const int32_t maxNum);
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
};
} // namespace matmul_tiling
#endif // TILING_MATMUL_MATH_UTIL_H

View File

@@ -0,0 +1,71 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
// IR definition of the MoeGatingTopK op: inputs x / bias, outputs
// y / expert_idx / out, plus the gating attributes.
// NOTE: the Attr() registration order below fixes the attr indices read by
// the tiling code (k=0, k_group=1, group_count=2, group_select_mode=3,
// renorm=4, norm_type=5, out_flag=6, routed_scaling_factor=7, eps=8).
class MoeGatingTopK : public OpDef {
public:
explicit MoeGatingTopK(const char *name) : OpDef(name)
{
// Each list entry corresponds to one supported x dtype combination
// (float16 / float32 / bfloat16); the format is always ND.
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("bias")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
// y keeps the dtype of x.
this->Output("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// expert_idx is always int32.
this->Output("expert_idx")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
// out (renorm result) is always float32.
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("k").Int();
this->Attr("k_group").AttrType(OPTIONAL).Int(1);
this->Attr("group_count").AttrType(OPTIONAL).Int(1);
this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0);
this->Attr("renorm").AttrType(OPTIONAL).Int(0);
this->Attr("norm_type").AttrType(OPTIONAL).Int(0);
this->Attr("out_flag").AttrType(OPTIONAL).Bool(false);
this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0);
this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f);
// Supported SoCs; ascend910_95 uses the regbase ("apt") kernel file.
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
OpAICoreConfig regbaseCfg;
regbaseCfg.DynamicCompileStaticFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt");
this->AICore().AddConfig("ascend910_95", regbaseCfg);
}
};
OP_ADD(MoeGatingTopK);
} // namespace ops

View File

@@ -0,0 +1,147 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_infershape.cpp
* \brief
*/
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "error_log.h"
#include <string>
#include <string>
#define TO_STRING(x) std::string(#x)
using namespace ge;
namespace ops {
static constexpr size_t DIM_ONE = 1;
static constexpr size_t DIM_TWO = 2;
static constexpr int64_t NEG_ONE = -1;
static constexpr int64_t X_INDEX = 0;
static constexpr int64_t BIAS_INDEX = 1;
static constexpr int64_t Y_INDEX = 0;
static constexpr int64_t EXPERT_IDX_INDEX = 1;
static constexpr int64_t OUT_INDEX = 2;
// Validates that the (rank-2) x dims are either concrete or the -1 unknown
// marker. Rank-1 x is the dynamic-rank (-2) case and yields -1 placeholders.
static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape)
{
    int64_t xRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
    int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
    if (xRows < NEG_ONE || expertNum < NEG_ONE) {
        // Fixed: TO_STRING(*xShape) stringified the literal token "*xShape";
        // log the actual shape via Shape::ToString(), and pass the node name
        // (not the context pointer) as the log tag.
        OP_LOGE(context->GetNodeType(), "Invalid x shape, shape is %s.", xShape->ToString().c_str());
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// Validates the rank of x (2, or the single -2 dynamic-rank marker) and that
// attr k is non-negative.
static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape,
    const int64_t k)
{
    if (xShape->GetDimNum() == 1U) {
        if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) {
            // Fixed: TO_STRING(*xShape) printed the token "*xShape" instead of
            // the shape contents; also tag logs with the node name.
            OP_LOGE(context->GetNodeType(), "The dynamic dim of x should be -2, current shape is %s.",
                xShape->ToString().c_str());
            return ge::GRAPH_FAILED;
        }
    } else if (xShape->GetDimNum() != DIM_TWO) {
        OP_LOGE(context->GetNodeType(), "The dim of x should be 2 or dynamic, current shape is %s.",
            xShape->ToString().c_str());
        return ge::GRAPH_FAILED;
    }
    if (k < 0) {
        OP_LOGE(context->GetNodeType(), "k must be a non-negative number.");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// Debug-logs the input shape and k. OP_LOGD compiles to a no-op in this
// build, so the arguments are not evaluated at runtime. Fixed the latent
// TO_STRING(*xShape) bug (it stringified the token, not the shape).
static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k)
{
    OP_LOGD(context, "x shape is: %s.", xShape->ToString().c_str());
    OP_LOGD(context, "k is: %ld.", k);
}
// Debug-logs the inferred output shapes (no-op under the current OP_LOGD).
// Fixed the latent TO_STRING(*shape) token-stringification bug.
static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape,
    const gert::Shape *expertIdxShape, const gert::Shape *outShape)
{
    OP_LOGD(context, "y shape is: %s after infershape.", yShape->ToString().c_str());
    OP_LOGD(context, "expert_idx shape is: %s after infershape.", expertIdxShape->ToString().c_str());
    OP_LOGD(context, "out shape is: %s after infershape.", outShape->ToString().c_str());
}
// Infers output shapes from x = (rows, expert_num) and attr k:
// y and expert_idx become (rows, k); out (the renorm buffer) becomes
// (rows, expert_num). Dynamic x propagates -1 placeholder dims.
static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context)
{
OP_LOGD(context, "Begin to do MoeGatingTopKInfershape.");
// Fetch the input shape and the three output shape slots.
const gert::Shape *xShape = context->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
gert::Shape *yShape = context->GetOutputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
gert::Shape *expertIdxShape = context->GetOutputShape(1);
OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape);
gert::Shape *outShape = context->GetOutputShape(2);
OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
// Fetch the required attr k (attr index 0).
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(0);
OP_CHECK_NULL_WITH_CONTEXT(context, kPtr);
const int64_t k = *kPtr;
ShowInputShapeInfo(context, xShape, k);
// Validate rank / attr values before reading dims.
if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
// Rank-1 x is the dynamic-rank (-2) marker: emit unknown (-1) dims.
int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
yShape->SetDimNum(DIM_TWO);
yShape->SetDim(0U, rows);
yShape->SetDim(1U, k);
expertIdxShape->SetDimNum(DIM_TWO);
expertIdxShape->SetDim(0U, rows);
expertIdxShape->SetDim(1U, k);
outShape->SetDimNum(DIM_TWO);
outShape->SetDim(0U, rows);
outShape->SetDim(1U, expertNum);
ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape);
OP_LOGD(context, "End to do MoeGatingTopKInfershape.");
return ge::GRAPH_SUCCESS;
}
// Infers output dtypes: y mirrors the gating input dtype, expert_idx is
// always int32, and the renorm output is always float32.
static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context)
{
    OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType.");
    const auto gatingDtype = context->GetInputDataType(0);
    context->SetOutputDataType(Y_INDEX, gatingDtype);
    context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32);
    context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT);
    OP_LOGD(context, "End to do MoeGatingTopKInferDataType.");
    return ge::GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK);
} // namespace ops

View File

@@ -0,0 +1,15 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.cpp
* \brief Translation unit that pulls in the MoeGatingTopK IR registration
* (REG_OP) declared in moe_gating_top_k_proto.h.
*/
#include "moe_gating_top_k_proto.h"

View File

@@ -0,0 +1,66 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.h
* \brief
*/
#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#include "graph/operator_reg.h"
namespace ge {
/**
* @brief Compute renorm(softmax/sigmoid) and topk for moe gating input.
*
* @par Inputs:
* @li x: A 2D tensor which moe gating topk is applied to. The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E (expert num) can not be greater than 2048 and should be divisible by group_count.
* @li bias: A 1D tensor which is the "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x.
*
* @par Outputs:
* @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x.
The size of the first axis must be the same as that of the corresponding axis of x.
The size of the last axis must be equal to k.
* @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int32. The shape must be the same as that of y.
* @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x.
*
* @par Attributes:
* @li k: A required attribute of type int. The value must be greater than 0 and less than or equal to expert_num / group_count * k_group, indicating the topk value.
* @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1.
* @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1.
* @li group_select_mode: An optional attribute of type int. 0 indicating that groups are sorted by max value, 1 indicating that groups are sorted by the sum of their top-2 values. The default value is 0.
* @li renorm: An optional attribute of type int. 0 indicating norm first and then topk; the tiling additionally accepts 1 (L1 renorm). The default value is 0.
* @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0.
* @li out_flag: An optional attribute of type bool. true indicating that the renorm output is produced, false indicating that it is not. The default value is false.
* @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0.
* @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20.
*/
REG_OP(MoeGatingTopK)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(expert_idx, TensorType({DT_INT32}))
.OUTPUT(out, TensorType({DT_FLOAT}))
.REQUIRED_ATTR(k, Int)
.ATTR(k_group, Int, 1)
.ATTR(group_count, Int, 1)
.ATTR(group_select_mode, Int, 0)
.ATTR(renorm, Int, 0)
.ATTR(norm_type, Int, 0)
.ATTR(out_flag, Bool, false)
.ATTR(routed_scaling_factor, Float, 1.0)
.ATTR(eps, Float, 1e-20f)
.OP_END_FACTORY_REG(MoeGatingTopK)
} // namespace ge
#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_

View File

@@ -0,0 +1,573 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling.cpp
* \brief
*/
#include <cmath>
#include "register/op_def_registry.h"
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "platform/platform_info.h"
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
// Ceil-align / ceil-divide helpers (guarded: other TUs may define them too).
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
// Supported attr values (see moe_gating_top_k_proto.h for their meaning).
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
// Expected ranks of the inputs/outputs.
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
// NOTE(review): "OUTPUY" is a typo for "OUTPUT"; kept as-is since usages may
// exist outside this view.
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
// Hard cap on the expert dimension E.
const static int64_t MAX_EXPERT_COUNT = 2048;
// Input/output slot indices — must match moe_gating_top_k_def.cpp ordering.
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
// Attr indices — must match the Attr() registration order in the op def.
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
// 16 MB default workspace.
const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216;
const static uint32_t DATATYPESIZE_FLOAT = 4;
// TopK sort configuration flags.
const static bool IS_LARGEST = true;
const static bool IS_INITINDEX = false;
const static bool IS_REUSESOURCE = false;
const static uint64_t WITH_GROUP_CONDITION = 1;
const static uint64_t WITHOUT_GROUP_CONDITION = 2;
const static uint64_t MAX_IN_GROUP_CONDITION = 3;
constexpr int32_t ROW_COUNT_PER_TASK = 1;
// Tiling-key variants selecting the kernel path.
const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0;
const static uint64_t TILING_KEY_WITHOUT_GROUP = 1;
const static uint64_t TILING_KEY_GENERALIZED = 2;
// Exact ceil(log4(x)) for x >= 1, computed with integer arithmetic instead of
// std::ceil(std::log(x) / std::log(4)), which can round the wrong way at
// exact powers of 4 and is undefined for x <= 0. For x <= 1 the result is 0.
inline static int64_t CeilLog4(int64_t x)
{
    int64_t levels = 0;
    // value grows as 4^levels; x <= 2^62 here, so the multiply cannot overflow
    // before value >= x.
    for (int64_t value = 1; value < x; value *= 4) {
        ++levels;
    }
    return levels;
}
// Tiling implementation for MoeGatingTopK, plugged into the common
// TilingBaseClass pipeline (GetShapeAttrsInfo -> DoOpTiling -> DoLibApiTiling
// -> GetWorkspaceSize -> PostTiling).
class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
    explicit MoeGatingTopKTilingBase(gert::TilingContext *context)
        : Ops::Transformer::OpTiling::TilingBaseClass(context)
    {
        Reset();
    }
    ~MoeGatingTopKTilingBase() override = default;

    // Re-arms this instance for a new tiling context.
    void Reset(gert::TilingContext *context) override
    {
        TilingBaseClass::Reset(context);
        Reset();
    }

protected:
    // This tiling path handles every supported shape/attr combination.
    bool IsCapable() override
    {
        return true;
    }
    ge::graphStatus GetPlatformInfo() override;
    ge::graphStatus GetShapeAttrsInfo() override;
    ge::graphStatus DoOpTiling() override;
    ge::graphStatus DoLibApiTiling() override;
    uint64_t GetTilingKey() const override;
    ge::graphStatus GetWorkspaceSize() override;
    ge::graphStatus PostTiling() override;
    void Reset();

private:
    ge::graphStatus CheckInputShape(); // validates x/bias shapes, records row/expert counts
    ge::graphStatus CheckAttr();       // validates attrs, records group counts
    ge::graphStatus CheckOutShape();   // validates y/expert_idx/out shapes
    void SplitRows();                  // splits rows across AI cores
    void CalTmpBufUbSize();            // sizes UB scratch buffers

    // Borrowed shape pointers; owned by the tiling context.
    const gert::Shape *xShape_ = nullptr;
    const gert::Shape *biasShape_ = nullptr;
    const gert::Shape *yShape_ = nullptr;
    const gert::Shape *expertIdxShape_ = nullptr;
    const gert::Shape *outShape_ = nullptr;
    int64_t rows_ = 0;                // B*S, first dim of x
    int64_t expertCount_ = 0;         // E, second dim of x
    int64_t addBias_ = 0;             // 1 when the optional bias input is present
    int64_t k_ = 0;
    int64_t kGroup_ = 0;
    int64_t groupCount_ = 0;
    int64_t perGroupExpertCount_ = 0; // expertCount_ / groupCount_
    int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
    int64_t renorm_ = RENORM_NO;
    int64_t normType_ = NORM_TYPE_SOFTMAX;
    int64_t outFlag_ = OUT_FLAG_FALSE;
    float routedScalingFactor_ = 1.0;
    float eps_ = 1e-20f;
    // Fixed: was declared without an initializer; default to 0 until the real
    // dtype size is resolved during GetShapeAttrsInfo.
    int64_t inputDtypeSize_ = 0;
    const char *opName_ = "";
    MoeGatingTopKTilingData moeGatingTopKTilingData_;
};
// Validates x (must be rank-2) and the optional bias (rank-1, length E),
// recording row/expert counts and the bias flag into the tiling data.
// Requires GetShapeAttrsInfo to have populated xShape_/biasShape_ and k_.
ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
// bias length must match the expert dimension of x.
OP_CHECK_IF(
biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Validates attribute combinations against the input shape and records the
// derived per-group expert counts into the tiling data.
// Message fixes vs. the previous version: the renorm message now lists both
// accepted values (0 and 1), the divisibility message names group_count (the
// value actually checked) instead of k_group, boundary texts match their
// comparisons, and a stray "333" debug prefix was removed. Log tags use the
// node name rather than the raw context pointer.
ge::graphStatus MoeGatingTopKTilingBase::CheckAttr()
{
    OP_CHECK_IF(
        expertCount_ > MAX_EXPERT_COUNT,
        OP_LOGE(context_->GetNodeType(), "expert count is: %ld, but should not be greater than %ld.", expertCount_,
            MAX_EXPERT_COUNT),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(k_ <= 0, OP_LOGE(context_->GetNodeType(), "k is: %ld, but should be greater than 0.", k_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(kGroup_ <= 0,
        OP_LOGE(context_->GetNodeType(), "k_group is: %ld, but should be greater than 0.", kGroup_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(kGroup_ > groupCount_,
        OP_LOGE(context_->GetNodeType(), "k_group is: %ld, but should not be greater than %ld.", kGroup_, groupCount_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(groupCount_ <= 0,
        OP_LOGE(context_->GetNodeType(), "group_count is: %ld, but should be greater than 0.", groupCount_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
        OP_LOGE(context_->GetNodeType(), "norm type is: %ld, but currently only support %ld and %ld.", normType_,
            NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
        OP_LOGE(context_->GetNodeType(), "group select mode is: %ld, but currently only support %ld and %ld.",
            groupSelectMode_, GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
        return ge::GRAPH_FAILED);
    // Both no-renorm (0) and L1 renorm (1) are accepted.
    OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
        OP_LOGE(context_->GetNodeType(), "renorm is: %ld, but currently only support %ld and %ld.", renorm_,
            RENORM_NO, RENORM_L1),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(expertCount_ % groupCount_ != 0,
        OP_LOGE(context_->GetNodeType(), "Expert count: %ld is not divisible by group_count: %ld", expertCount_,
            groupCount_),
        return ge::GRAPH_FAILED);
    perGroupExpertCount_ = expertCount_ / groupCount_;
    OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
    OP_CHECK_IF(perGroupExpertCount_ < 1,
        OP_LOGE(context_->GetNodeType(), "group expert count is: %ld, but should be greater than 0.",
            perGroupExpertCount_),
        return ge::GRAPH_FAILED);
    // Top-2 sum selection needs at least 2 experts per group.
    OP_CHECK_IF(
        groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
        OP_LOGE(context_->GetNodeType(),
            "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
            perGroupExpertCount_, groupSelectMode_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
        OP_LOGE(context_->GetNodeType(), "k is: %ld, but should not be greater than %ld.", k_,
            kGroup_ * perGroupExpertCount_),
        return ge::GRAPH_FAILED);
    // Per-group expert count rounded up to a 32-element boundary for sorting.
    int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L);
    OP_LOGI(context_, "groupExpertCountAlign: %ld", groupExpertCountAlign);
    if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) {
        OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
            OP_LOGE(context_->GetNodeType(),
                "group count * group expert count align is: %ld, but should not be greater than %ld.",
                groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
            return ge::GRAPH_FAILED);
    }
    moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_);
    moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
    return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_);
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str());
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
if (biasShape_ != nullptr) {
OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str());
}
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str());
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str());
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr);
outShape_ = &outPtr->GetStorageShape();
if (outShape_ != nullptr) {
OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str());
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str());
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str());
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str());
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc);
auto normOutDtype = normOutDesc->GetDataType();
OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.",
ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()),
return ge::GRAPH_FAILED);
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
OP_LOGI(context_, "Attr k is: %ld", k_);
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
OP_LOGI(context_, "Attr k_group is: %ld", kGroup_);
moeGatingTopKTilingData_.set_kGroup(kGroup_);
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
OP_LOGI(context_, "Attr group_count is: %ld", groupCount_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
}
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_);
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
}
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
OP_LOGI(context_, "Attr renorm is: %ld", renorm_);
moeGatingTopKTilingData_.set_renorm(renorm_);
}
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
OP_LOGI(context_, "Attr norm_type is: %ld", normType_);
moeGatingTopKTilingData_.set_normType(normType_);
}
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_);
moeGatingTopKTilingData_.set_outFlag(outFlag_);
}
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_);
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
}
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
OP_LOGI(context_, "Attr eps is: %f", eps_);
moeGatingTopKTilingData_.set_eps(eps_);
}
OP_LOGI(context_, "Attr eps is: %f ", eps_);
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_);
return ge::GRAPH_SUCCESS;
}
// Query the AIV core count and UB memory size from the platform and cache
// them in aicoreParams_ for the later tiling steps.
ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo()
{
    auto platInfoPtr = context_->GetPlatformInfo();
    OP_CHECK_IF(platInfoPtr == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
    auto plat = platform_ascendc::PlatformAscendC(platInfoPtr);
    aicoreParams_.blockDim = plat.GetCoreNumAiv();
    uint64_t ubBytes;
    plat.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    aicoreParams_.ubSize = ubBytes;
    OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize);
    return ge::GRAPH_SUCCESS;
}
// Validate the output shapes (y, expert_idx, optional norm out) against the
// input x: same rank, same dim[0] (row count); dim[1] of y/expert_idx must be
// k, dim[1] of norm out must match x's dim[1] (expert count).
// Fixes: removed leftover debug prefix "555" from the entry log; the norm-out
// dim[0] error log previously printed outShape_->GetDim(0) for both values;
// corrected "euqal" typos in the error messages.
ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape()
{
    OP_LOGI(context_, "CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());
    OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
            xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
            expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    if (outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
            OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
                outShape_->GetDimNum(), xShape_->GetDimNum()),
            return ge::GRAPH_FAILED);
    }
    OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "y out dim[0] %ld not equal x dim[0] %ld, please check.", yShape_->GetDim(0),
            xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "expertId out dim[0] %ld not equal x dim[0] %ld, please check.",
            expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        // Bugfix: second log argument was outShape_->GetDim(0) (always equal to
        // the first), so the x dim[0] shown to the user was wrong.
        OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
            OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
                outShape_->GetDim(0), xShape_->GetDim(0)),
            return ge::GRAPH_FAILED);
    }
    OP_CHECK_IF((yShape_->GetDim(1) != k_),
        OP_LOGE(context_, "y dim[1] %ld not equal k %ld, please check.", yShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
        OP_LOGE(context_, "expertId dim[1] %ld not equal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
            OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
                xShape_->GetDim(1)),
            return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
void MoeGatingTopKTilingBase::SplitRows()
{
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
// perCoreRows cannot be 0
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L));
OP_LOGI(context_, "vms count is: %ld", vmsCount);
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
}
// Compute the UB scratch-buffer size: the larger of the AscendC Sigmoid API's
// minimum temp size and an index buffer of expertCount_ floats (count rounded
// up to a multiple of 32 elements).
void MoeGatingTopKTilingBase::CalTmpBufUbSize()
{
std::vector<int64_t> shape_vec = {expertCount_};
ge::Shape shape(shape_vec);
uint32_t maxValue = 0;
uint32_t minValue = 0;
// Query Sigmoid temp-buffer bounds; only minValue is used below, maxValue
// is required by the API signature.
AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue);
// expertCount_ rounded up to a multiple of 32, times sizeof(float).
int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast<int64_t>(sizeof(float));
moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast<int64_t>(minValue)));
}
// Run the validation pipeline (input shape, output shape, attributes) and,
// on success, derive the scratch-buffer size and the per-core row split.
ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling()
{
    OP_LOGI(context_, "DoOpTiling: start");
    ge::graphStatus status = CheckInputShape();
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckOutShape();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckAttr();
    }
    if (status != ge::GRAPH_SUCCESS) {
        return status;
    }
    CalTmpBufUbSize();
    SplitRows();
    return ge::GRAPH_SUCCESS;
}
// No additional AscendC library-API tiling is needed for the base template.
ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
// Request a fixed-size device workspace (DEFAULT_WORKSPACE_SIZE).
ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize()
{
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
// Publish the tiling result into the context: block dim, workspace size, and
// the serialized tiling data buffer.
ge::graphStatus MoeGatingTopKTilingBase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
// Serialize the tiling data into the raw buffer owned by the context.
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
// Select the kernel variant: a specialized high-performance path for one
// specific configuration (256 experts, 8 groups, k_group 4, k <= 32, bias,
// sum group-select, no renorm, sigmoid, no norm out); otherwise a groupless
// path when grouping is degenerate; otherwise the generalized path.
uint64_t MoeGatingTopKTilingBase::GetTilingKey() const
{
    const bool highPerfCase = expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 &&
                              addBias_ && groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO &&
                              normType_ == NORM_TYPE_SIGMOID && !outFlag_;
    if (highPerfCase) {
        return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF;
    }
    if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) {
        return TILING_KEY_WITHOUT_GROUP;
    }
    return TILING_KEY_GENERALIZED;
}
// Clear per-invocation state so the template instance can be reused.
void MoeGatingTopKTilingBase::Reset()
{
    opName_ = nullptr;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000);
} // namespace optiling

View File

@@ -0,0 +1,86 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
* \file moe_gating_top_k_tiling.h
* \brief
*/
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#include <cmath>
#include <cstdint>
#include <vector>
#include <algorithm>
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "platform/platform_infos_def.h"
#include "math_util.h"
//#include "util/extern_math_util.h"
namespace optiling {
// Tiling data consumed by the MoeGatingTopK kernel (base template).
BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData)
// Core split: active core count and rows handled per (last) core.
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
// Problem size and whether the optional bias input is present.
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
// Top-k / grouping attributes mirrored from the op.
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
// Number of 4-way merge-sort rounds (see SplitRows).
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
// UB scratch size computed by CalTmpBufUbSize.
TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData)
// Tiling data for the regbase (ASCEND910_95) kernel variant. Layout matches
// MoeGatingTopKTilingData except the trailing calTmpBufUbSize field is
// replaced by an embedded SoftMaxTiling structure.
BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
// Tiling for the AscendC SoftMax API, filled by CalTmpBufUbSize.
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData);
END_TILING_DATA_DEF;
// Registered under tiling key 10000 (MOE_GATING_TOP_K_REGBASE_TILING_KEY).
REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData)
// No compile-time (TilingParse) information is needed for this op.
struct MoeGatingTopKCompileInfo {};
} // namespace optiling
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H

View File

@@ -0,0 +1,521 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/* !
* \file moe_gating_top_k_tiling_arch35.cpp
* \brief
*/
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "platform/platform_info.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
// Tiling key returned by the regbase (910_95) template.
const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000;
// group_select_mode attribute values: select groups by max score vs. by sum.
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
// renorm attribute values.
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
// norm_type attribute values.
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
// out_flag attribute values (emit norm output or not).
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
// Expected ranks of inputs and outputs.
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
// NOTE(review): "OUTPUY" is a typo in this identifier; kept as-is so that
// any references elsewhere keep compiling.
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
// Upper bound enforced on x's expert dimension.
const static int64_t MAX_EXPERT_COUNT = 2048;
// Input/output indices in the op prototype.
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
// Attribute indices in the op prototype.
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
// Number of lists merged per hardware merge-sort round.
const static int64_t MRGSORT_SIZE = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast<int64_t>(16 * 1024 * 1024); // reserve 16 MB of workspace
// Tiling template for the ASCEND910_95 ("regbase") variant of MoeGatingTopK.
// Registered with priority 1000 (see REGISTER_OPS_TILING_TEMPLATE below), so
// it is tried before the generic base template; IsCapable() restricts it to
// the 910_95 SoC.
class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
{
Reset();
}
~MoeGatingTopKTilingRegbase() override = default;
void Reset(gert::TilingContext *context) override
{
TilingBaseClass::Reset(context);
Reset();
}
protected:
// This template only handles the ASCEND910_95 SoC; socVersion is filled in
// GetPlatformInfo().
bool IsCapable() override
{
if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) {
return false;
}
return true;
}
ge::graphStatus GetPlatformInfo() override;
ge::graphStatus GetShapeAttrsInfo() override;
ge::graphStatus DoOpTiling() override;
ge::graphStatus DoLibApiTiling() override;
uint64_t GetTilingKey() const override;
ge::graphStatus GetWorkspaceSize() override;
ge::graphStatus PostTiling() override;
void Reset();
private:
ge::graphStatus CheckInputShape();
ge::graphStatus CheckAttr();
ge::graphStatus CheckOutShape();
void CalTmpBufUbSize();
void SplitRows();
void Tiling4GatherOutComputeSplitK();
// Cached views of the node's input/output storage shapes (owned by the
// tiling context; outShape_/biasShape_ stay null when the tensor is absent).
const gert::Shape *xShape_ = nullptr;
const gert::Shape *biasShape_ = nullptr;
const gert::Shape *yShape_ = nullptr;
const gert::Shape *expertIdxShape_ = nullptr;
const gert::Shape *outShape_ = nullptr;
// Problem sizes taken from x's shape in CheckInputShape().
int64_t rows_;
int64_t expertCount_;
int64_t addBias_ = 0;
// Attributes parsed in GetShapeAttrsInfo(); defaults apply when an
// optional attribute is absent.
int64_t k_;
int64_t kGroup_ = 1;
int64_t groupCount_ = 1;
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
int64_t renorm_ = RENORM_NO;
int64_t normType_ = NORM_TYPE_SOFTMAX;
int64_t outFlag_ = OUT_FLAG_FALSE;
float routedScalingFactor_ = 1.0;
float eps_ = 1e-20f;
int64_t inputDtypeSize_;
const char *opName_ = "";
MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_;
platform_ascendc::SocVersion socVersion;
};
// Validate x (2-D: rows x expertCount, expertCount <= MAX_EXPERT_COUNT) and
// the optional bias (1-D of length expertCount); record the row/expert counts
// and the add-bias flag in the tiling data.
ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
OP_CHECK_IF(
expertCount_ > MAX_EXPERT_COUNT,
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.",
biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
// k_ was already parsed in GetShapeAttrsInfo, which runs before DoOpTiling.
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Validate attribute ranges and their mutual consistency, then derive the
// per-group expert counts. NOTE: this method also normalizes degenerate
// grouping (k_group == group_count or group_count == expert_count) by
// collapsing both kGroup_ and groupCount_ to 1 before writing them into the
// tiling data, so it must run before CalTmpBufUbSize/SplitRows.
ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr()
{
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ > groupCount_,
OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_),
return ge::GRAPH_FAILED);
// When group_count == expert_count each group holds exactly one expert, so
// k_group * per-group count == kGroup_; check before collapsing below.
OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_, k_),
return ge::GRAPH_FAILED);
// Degenerate grouping: all groups selected, or one expert per group —
// treat both as "no grouping".
if (kGroup_ == groupCount_ || groupCount_ == expertCount_) {
kGroup_ = 1;
groupCount_ = 1;
}
moeGatingTopKTilingData_.set_kGroup(kGroup_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
int64_t groupExpertCount = expertCount_ / groupCount_;
// Align the per-group expert count to 32 elements for the kernel.
int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L);
moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_);
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ * groupExpertCount < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_ * groupExpertCount, k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupExpertCount < 1,
OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
return ge::GRAPH_FAILED);
// Sum mode aggregates at least the top-2 experts of a group.
OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2,
OP_LOGE(context_,
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
groupExpertCount, groupSelectMode_),
return ge::GRAPH_FAILED);
// The regbase path does not implement L1 renorm yet.
OP_CHECK_IF(renorm_ != RENORM_NO,
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
return ge::GRAPH_FAILED);
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
// Collect the node's shapes, validate dtypes, and parse all attributes into
// members and into the tiling data. k is mandatory; every other attribute
// falls back to the member default when absent. Shape-consistency checks are
// deferred to CheckInputShape/CheckOutShape.
ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
// bias is optional; biasShape_ stays nullptr when it is absent.
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
// The norm output is optional as well.
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
if (outPtr != nullptr) {
outShape_ = &outPtr->GetStorageShape();
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
// NOTE(review): GetOptionalInputDesc is dereferenced without a null check
// here; presumed non-null whenever the shape exists — confirm.
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
// Attribute parsing: each effective value is mirrored into the tiling data.
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
}
// kGroup_/groupCount_ are written to the tiling data later in CheckAttr,
// after degenerate grouping has been normalized.
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
}
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
}
moeGatingTopKTilingData_.set_renorm(renorm_);
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
}
moeGatingTopKTilingData_.set_normType(normType_);
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
}
moeGatingTopKTilingData_.set_outFlag(outFlag_);
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
}
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
}
moeGatingTopKTilingData_.set_eps(eps_);
OP_LOGI(context_, "Attr eps is: %f ", eps_);
// The norm-out dtype is only validated when out_flag is set and a desc exists.
auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
if (outFlag_ && outDesc != nullptr) {
auto outDtype = outDesc->GetDataType();
OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.",
ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()),
return ge::GRAPH_FAILED);
}
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
return ge::GRAPH_SUCCESS;
}
// Cache the AIV core count, SoC version and UB size for later tiling steps;
// socVersion also drives IsCapable().
ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo()
{
    auto rawPlatform = context_->GetPlatformInfo();
    OP_CHECK_IF(rawPlatform == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
    auto ascendPlat = platform_ascendc::PlatformAscendC(rawPlatform);
    aicoreParams_.blockDim = ascendPlat.GetCoreNumAiv();
    socVersion = ascendPlat.GetSocVersion();
    uint64_t ubBytes;
    ascendPlat.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    aicoreParams_.ubSize = ubBytes;
    return ge::GRAPH_SUCCESS;
}
// Validate the output shapes (y, expert_idx, optional norm out) against the
// input x: same rank, same dim[0] (row count); dim[1] of y/expert_idx must be
// k, dim[1] of norm out must match x's dim[1] (expert count).
// Fixes: the norm-out dim[0] error log previously printed outShape_->GetDim(0)
// for both values; corrected "euqal" typos in the error messages.
ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape()
{
    OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
            xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
            expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    if (outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
            OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
                outShape_->GetDimNum(), xShape_->GetDimNum()),
            return ge::GRAPH_FAILED);
    }
    OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "y out dim[0] %ld not equal x dim[0] %ld, please check.", yShape_->GetDim(0),
            xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "expertId out dim[0] %ld not equal x dim[0] %ld, please check.",
            expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        // Bugfix: second log argument was outShape_->GetDim(0) (always equal to
        // the first), so the x dim[0] shown to the user was wrong.
        OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
            OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
                outShape_->GetDim(0), xShape_->GetDim(0)),
            return ge::GRAPH_FAILED);
    }
    OP_CHECK_IF((yShape_->GetDim(1) != k_),
        OP_LOGE(context_, "y dim[1] %ld not equal k %ld, please check.", yShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
        OP_LOGE(context_, "expertId dim[1] %ld not equal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
            OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
                xShape_->GetDim(1)),
            return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
// Compute the AscendC SoftMax API tiling for one padded row of
// groupCount_ * perGroupExpertCountAlign float elements; the result is stored
// in the embedded softmaxTilingData struct.
void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() {
std::vector<int64_t> shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()};
ge::Shape softmaxShape(shape_vec);
uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true);
AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData);
}
// Distribute rows_ across AIV cores (last active core takes the remainder)
// and count the extra 4-way merge-sort (vms) rounds needed when kGroup_
// exceeds one hardware merge width.
void MoeGatingTopKTilingRegbase::SplitRows()
{
    int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
    // Bugfix: the zero guard must come before any division by perCoreRows.
    // Previously needCoreNum was computed first, so rows_ == 0 divided by zero
    // before the guard could trigger.
    if (perCoreRows == 0) {
        OP_LOGE(context_, "perCoreRows can't be 0.");
        return;
    }
    int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
    int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
    moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
    moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
    moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
    // vmsCount counts how many times the merge width (MRGSORT_SIZE = 4) must
    // be multiplied to cover kGroup_ lists; 0 when kGroup_ fits in one round.
    int64_t vmsCount = 0;
    if (kGroup_ > MRGSORT_SIZE) {
        int64_t index = MRGSORT_SIZE;
        while (index < kGroup_) {
            index = index * MRGSORT_SIZE;
            vmsCount++;
        }
    }
    moeGatingTopKTilingData_.set_vmsCount(vmsCount);
}
// Run the validation pipeline (input shape, attributes, output shape) and,
// on success, compute the SoftMax tiling and the per-core row split.
ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling()
{
    ge::graphStatus status = CheckInputShape();
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckAttr();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckOutShape();
    }
    if (status != ge::GRAPH_SUCCESS) {
        return status;
    }
    CalTmpBufUbSize();
    SplitRows();
    return ge::GRAPH_SUCCESS;
}
// Library-API tiling is already done inside CalTmpBufUbSize; nothing to add.
ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
// Request a fixed-size device workspace (DEFAULT_WORKSPACE_SIZE).
ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize()
{
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
// Publish the tiling result into the context: block dim, workspace size, and
// the serialized tiling data buffer.
ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
// Serialize the tiling data into the raw buffer owned by the context.
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
// The regbase (910_95) implementation uses a single tiling key (10000).
uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const
{
return MOE_GATING_TOP_K_REGBASE_TILING_KEY;
}
// Drop per-invocation state; called from the constructor and Reset(context).
void MoeGatingTopKTilingRegbase::Reset()
{
    opName_ = nullptr;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000);
} // namespace optiling

View File

@@ -0,0 +1,38 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/* !
* \file moe_gating_top_k_tiling_base.cpp
* \brief
*/
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "error_log.h"
#include "kernel_tiling/kernel_tiling.h"
namespace optiling {
// Tiling entry point: delegate to the registered tiling templates via the
// shared tiling registry (templates pick themselves through IsCapable()).
static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context)
{
return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
}
// Tiling-parse hook: no compile-time info is needed for this op.
static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context)
{
    (void)context;
    return ge::GRAPH_SUCCESS;
}
// Bind the tiling and tiling-parse functions to the MoeGatingTopK op.
IMPL_OP_OPTILING(MoeGatingTopK)
    .Tiling(TilingForMoeGatingTopK)
    .TilingParse<MoeGatingTopKCompileInfo>(TilingPrepareForMoeGatingTopK);
} // namespace optiling

View File

@@ -0,0 +1,89 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file common.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_COMMON_H
#define MOE_GATING_TOP_K_COMMON_H
#include "kernel_operator.h"
namespace MoeGatingTopK {
using namespace AscendC;
const float MIN_FP32 = *(float *)(&F32_NEG_INF); // reinterpret the -inf bit pattern as float
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000;  // IEEE754 -inf bit pattern (as int32_t this is -8388608)
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;      // elements handled by one Sort32 repeat
constexpr int64_t BLOCK_BYTES = 32;              // UB block granularity in bytes
constexpr int64_t REPEAT_BYTES = 256;            // bytes processed per vector repeat
constexpr int64_t REPEAT_BLOCKS = 8;             // blocks per vector repeat
constexpr int32_t CONSTANT_TWO = 2;
constexpr int32_t CONSTANT_THREE = 3;
constexpr int32_t CONSTANT_FOUR = 4;
constexpr int32_t CONSTANT_EIGHT = 8;
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;
constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;
constexpr int64_t NORM_TYPE_SOFTMAX = 0; // normalize gate scores with softmax
constexpr int64_t NORM_TYPE_SIGMOID = 1; // normalize gate scores with sigmoid
// Ceiling division; a zero divisor yields 0 instead of trapping.
__aicore__ inline int64_t Ceil(int64_t a, int64_t b)
{
    return (b == 0) ? 0 : (a + b - 1) / b;
}
// Round elementNum up so that elementNum * bytes becomes a multiple of
// BLOCK_BYTES; returns 0 for a zero element size.
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes)
{
    if (bytes == 0) {
        return 0;
    }
    const int64_t alignedBytes = (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
    return alignedBytes / bytes;
}
// Total byte size of elementNum elements, rounded up to a BLOCK_BYTES boundary.
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes)
{
    const int64_t rawBytes = elementNum * bytes;
    return (rawBytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
}
// Smaller of two values; returns a on ties.
template <typename T>
__aicore__ inline T Min(T a, T b)
{
    if (a > b) {
        return b;
    }
    return a;
}
// Larger of two values; returns a on ties.
template <typename T>
__aicore__ inline T Max(T a, T b)
{
    if (a < b) {
        return b;
    }
    return a;
}
// Ceiling division that also handles negative operands: rounds up only when
// the remainder is non-zero and both operands have the same sign.
// Returns x unchanged when either operand is 0.
template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 x, T2 y)
{
    if (y != 0 && x != 0) {
        const T1 quotient = x / y;
        // (x ^ y) >= 0 iff x and y share a sign bit, i.e. the true quotient is positive.
        return (x % y != 0 && ((x ^ y) >= 0)) ? (quotient + 1) : quotient;
    }
    return x;
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_COMMON_H

View File

@@ -0,0 +1,55 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
// Lightweight logging shims. Info/debug logs compile to nothing.
#define OP_LOGI(opname, ...)
// Fix: the original macros passed the caller's format/arguments as *extra*
// printf arguments after a "%s"-only format string, so the actual log message
// was never printed. Print the tag first, then the message itself.
// `"" __VA_ARGS__` keeps zero-argument invocations valid via literal concatenation.
#define OP_LOGW(opname, ...)             \
    do {                                 \
        printf("[WARN][%s] ", (opname)); \
        printf("" __VA_ARGS__);          \
        printf("\n");                    \
    } while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do {                                    \
        printf("[ERRORx][%s] ", (opname));  \
        printf("" __VA_ARGS__);             \
        printf("\n");                       \
    } while (0)
#define OP_LOGE(opname, ...)              \
    do {                                  \
        printf("[ERROR][%s] ", (opname)); \
        printf("" __VA_ARGS__);           \
        printf("\n");                     \
    } while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
// Report a tiling error without raising a framework-level error report.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// If cond holds, emit the log and run expr (typically a `return` statement).
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)
// Fail the current function with GRAPH_FAILED when ptr is null.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,63 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k.cpp
* \brief
*/
#include "moe_gating_top_k_e_k_fullload.h"
#include "moe_gating_top_k_without_group.h"
#include "moe_gating_top_k_generalized.h"
#include "error_log.h"
// Tiling keys dispatched by the kernel entry below; must match host-side tiling.
#define TILING_KEY_PER_GROUP_COUNT_32 0
#define TILING_KEY_WITHOUT_GROUP 1
#define TILING_KEY_GENERALIZED 2
using namespace AscendC;
using namespace MoeGatingTopK;
// Kernel entry for MoeGatingTopK: selects one of three implementations by
// tiling key (group-of-32 fullload, group-free, or fully generalized path).
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    // Vector-only kernel: cube cores exit immediately.
    if (g_coreType == AIC) {
        return;
    }
    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling);
    if (workspace == nullptr) {
        return;
    }
    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }
    const MoeGatingTopKTilingData *__restrict t = &tilingData;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) {
        MoeGatingTopKEKFullload<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) {
        MoeGatingTopKWithoutGroup<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) {
        MoeGatingTopKGenerlized<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    }
}

View File

@@ -0,0 +1,46 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_apt.cpp
* \brief
*/
#include "arch35/moe_gating_top_k_regbase.h"
using namespace AscendC;
using namespace MoeGatingTopK;
#define TILING_KEY_REGBASE 10000
// Kernel entry for the arch35 (regbase) build of MoeGatingTopK; dispatches the
// single regbase implementation by tiling key.
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    // Vector-only kernel: cube cores exit immediately.
    if (g_coreType == AIC) {
        return;
    }
    if (workspace == nullptr) {
        return;
    }
    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }
    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling);
    const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_REGBASE)) {
        MoeGatingTopKRegbase<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe);
        op.Process();
    }
}

View File

@@ -0,0 +1,404 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_e_k_fullload.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H
#define MOE_GATING_TOP_K_E_K_FULLLOAD_H
#include "kernel_operator.h"
#include "common.h"
namespace MoeGatingTopK {
using namespace AscendC;
// Fullload implementation of MoeGatingTopK: each row's expert scores (E) and
// the top-k selection are held entirely in UB. One row is processed per
// pipeline iteration: sigmoid-normalize, add bias, per-group Sort32, select
// top-kGroup groups, merge-sort their experts, gather the top-k scores.
template <typename T>
class MoeGatingTopKEKFullload {
public:
    __aicore__ inline MoeGatingTopKEKFullload(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();
private:
    // Per-row pipeline stages, invoked in order by Process().
    __aicore__ inline void CopyInBias();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    __aicore__ inline void SortInGroup();
    __aicore__ inline void SelectTopKGroupIndex();
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CopyOut(int64_t progress);
private:
    TPipe *pipe_;
    // UB queues/buffers for the staged pipeline.
    TQue<QuePosition::VECIN, 1> xInQueue_;
    TBuf<TPosition::VECCALC> biasInQueue_;
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;
    TQue<QuePosition::VECOUT, 1> xBiasQueue_;
    TQue<QuePosition::VECOUT, 1> xSigmoidQueue_;
    TQue<QuePosition::VECIN, 1> sigmoidTmpQueue_;
    TQue<QuePosition::VECIN, 1> sortedInGroupQueue_;
    TQue<QuePosition::VECIN, 1> sortedGroupQueue_;
    TBuf<TPosition::VECCALC> calcTmpBuffer_;
    // Global-memory views, offset per core in Init().
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<T> outGm_;
    // Cached tiling parameters (see Init()).
    int64_t blockIdx_;
    int64_t perCoreRowCount_;
    int64_t curCoreRowCount_;
    int64_t expertCount_;
    bool addBias_;
    int64_t k_;
    int64_t kGroup_;
    int64_t groupCount_;
    int64_t groupSelectMode_;
    int64_t renorm_;
    int64_t normType_;
    int64_t outFlag_;
    float routedScalingFactor_;
    float eps_;
    int64_t expertCountAlign_;
    int64_t kAlign_;
    int64_t perGroupExpertCount_;
    const MoeGatingTopKTilingData *tilingData_;
};
// Load the bias vector (E elements) from GM into UB once, casting half/bf16
// input to float in place (raw data is staged in the upper half of the buffer).
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInBias()
{
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
    } else {
        // Stage the narrow-typed data after the float area, then widen to float.
        DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE, expertCount_);
    }
}
// Load one row of expert scores (E elements) from GM into UB, widening
// half/bf16 input to float (raw data staged in the upper half of the buffer).
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
// Normalize the row with sigmoid, then add the bias (or copy through via a
// zero-add when no bias is configured). Produces both the normalized scores
// (used later to gather the final top-k values) and the biased scores (used
// for ranking).
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::ComputeX()
{
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.AllocTensor<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xBiasTensor = xBiasQueue_.AllocTensor<float>();
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    LocalTensor<uint8_t> sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor<uint8_t>(); // temporary buffer, reusable
    Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_);
    PipeBarrier<PIPE_V>();
    if (addBias_) {
        Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_);
    } else {
        // Adds with 0 acts as a copy into the ranking buffer.
        Adds(xBiasTensor, xSigmoidTensor, static_cast<float>(0), expertCount_);
    }
    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    xBiasQueue_.EnQue<float>(xBiasTensor);
    xInQueue_.FreeTensor(xInLocalTensor);
    sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer);
}
// Sort the biased scores within each 32-element group (Sort32), attaching the
// expert index to every value so the merge stage can recover expert ids.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SortInGroup()
{
    LocalTensor<float> xBiasTensor = xBiasQueue_.DeQue<float>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.AllocTensor<float>(); // per-group sort results, needed by the later merge
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>(); // indices attached during sorting
    ArithProgression(indexTensor.ReinterpretCast<int32_t>(), 0, 1, expertCount_); // generate indices 0 1 2 ...
    PipeBarrier<PIPE_V>();
    Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // sort within each group
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    xBiasQueue_.FreeTensor(xBiasTensor);
}
// Select the top-kGroup groups, ranked by the sum of each group's two largest
// scores, and leave the chosen group ids (sorted) at the front of calcTmpBuffer_.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKGroupIndex()
{
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>();
    LocalTensor<float> top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor<float>(); // temporary buffer, reusable
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Gather mask b0101 picks the two largest (value, index) pairs of each sorted group.
    indexTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
    indexTensor.SetValue(1, static_cast<uint32_t>(0));
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    uint64_t rsvdCnt = 0; // number of elements kept after filtering
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = 8;
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = 8;
    gatherMaskParams.src1RepeatStride = 0;
    GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast<uint32_t>(64),
               gatherMaskParams, rsvdCnt);
    PipeBarrier<PIPE_V>();
    LocalTensor<float> groupTop2SumTensor = top2ValueInGroupTensor;
    PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1,
                  1); // sum of the two largest values in each group
    PipeBarrier<PIPE_V>();
    LocalTensor<uint32_t> groupIndexTensor = indexTensor;
    ArithProgression(groupIndexTensor.ReinterpretCast<int32_t>(), 0, 1, groupCount_); // generate group indices
    PipeBarrier<PIPE_V>();
    // Pad to 32 entries with the minimum float so padding never wins the sort.
    int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << groupCount_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    // Sort so the top-kGroup groups come first.
    LocalTensor<float> sortedGroupTensor = sortedGroupQueue_.AllocTensor<float>();
    Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1);
    PipeBarrier<PIPE_V>();
    LocalTensor<int32_t> sortedGroupIndexTensor = indexTensor.ReinterpretCast<int32_t>();
    // Extract the group ids.
    uint8_t src1Pattern = 2; // built-in fixed gather pattern
    GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
    // Sort the selected group ids (this sort is descending, so MrgSort later consumes the lists in reverse order 3, 2, 1, 0).
    Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_);
    PipeBarrier<PIPE_V>();
    duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << kGroup_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast<uint32_t>(), 1);
    PipeBarrier<PIPE_V>();
    src1Pattern = 1;
    GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast<uint32_t>(0), {1, 1, 0, 0},
               rsvdCnt);
    PipeBarrier<PIPE_V>();
    Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_);
    sortedGroupQueue_.FreeTensor(sortedGroupTensor);
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor);
}
// Merge-sort the per-group sorted lists of the four selected groups (four-way
// MrgSort with early exhaustion after k elements) and extract the top-k expert
// ids from the merged (value, index) pairs.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertIdx()
{
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> topKGroupIndexTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<float> sortedExpertTensor = xInQueue_.AllocTensor<float>();
    AscendC::MrgSort4Info params;
    params.elementLengths[0] = k_;
    params.elementLengths[1] = k_;
    params.elementLengths[2] = k_;
    params.elementLengths[3] = k_;
    params.ifExhaustedSuspension = true; // stop once k elements are merged
    params.validBit = 0b1111;
    params.repeatTimes = 1;
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Group ids were sorted descending, so the lists are taken in reverse (3, 2, 1, 0).
    // Each (value, index) pair occupies 2 floats, hence the factor of 2.
    int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2;
    int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2;
    int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2;
    int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2;
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = sortedInGroupTensor[listOffset1];
    srcList.src2 = sortedInGroupTensor[listOffset2];
    srcList.src3 = sortedInGroupTensor[listOffset3];
    srcList.src4 = sortedInGroupTensor[listOffset4];
    MrgSort<float>(sortedExpertTensor, srcList, params);
    PipeBarrier<PIPE_V>();
    uint64_t rsvdCnt = 0; // number of elements kept after filtering
    uint8_t src1Pattern = 2; // built-in fixed gather pattern
    GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
    xInQueue_.FreeTensor(sortedExpertTensor);
    expertIdxOutQueue_.EnQue(expertIdxTensor);
    sortedInGroupQueue_.FreeTensor(sortedInGroupTensor);
}
// Gather the normalized (pre-bias) scores of the selected experts, renormalize
// them to sum to 1 (with eps), scale by routedScalingFactor, and cast back to
// the output dtype when T is not float.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertScore()
{
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    LocalTensor<int32_t> expertByteIdxTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
    LocalTensor<T> yTensor = yOutQueue_.AllocTensor<T>();
    LocalTensor<float> yOutTensor;
    if constexpr (!IsSameType<T, float>::value) {
        // Compute in float in the upper half of the buffer, cast down at the end.
        yOutTensor = yTensor.template ReinterpretCast<float>()[kAlign_];
    } else {
        yOutTensor = yTensor;
    }
    // Gather needs byte offsets, so scale indices by sizeof(float).
    Muls(expertByteIdxTensor, expertIdxTensor, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast<uint32_t>(),
           static_cast<uint32_t>(0), k_);
    LocalTensor<float> calTensor = calcTmpBuffer_.Get<float>();
    PipeBarrier<PIPE_V>();
    ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_);
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    float sumValue = calTensor.GetValue(0) + eps_;
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    Duplicate(calTensor, sumValue, k_);
    PipeBarrier<PIPE_V>();
    Div(yOutTensor, yOutTensor, calTensor, k_);
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_);
    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_);
    }
    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxTensor);
    yOutQueue_.EnQue(yTensor);
}
// Write the row's top-k scores and expert ids back to GM and release the
// UB tensors held by this iteration.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyOut(int64_t row)
{
    LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
    dataCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams);
    xSigmoidQueue_.FreeTensor(xSigmoidTensor);
    expertIdxOutQueue_.FreeTensor(expertIdxTensor);
    yOutQueue_.FreeTensor(yOutTensor);
}
// Cache tiling parameters, bind this core's GM row range and allocate all UB
// queues/buffers used by the per-row pipeline.
// Fix: kAlign_ was computed from expertCount_ (copy-paste of the line above),
// which over-allocated yOutQueue_; it only sizes the k-element output tensor
// (offset [kAlign_] in SelectTopKExpertScore and the yOutQueue_ buffer here),
// so it must align k_.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core may own a shorter tail of rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    routedScalingFactor_ = tilingData_->routedScalingFactor;
    eps_ = tilingData_->eps;
    expertCountAlign_ = Align(expertCount_, sizeof(float));
    kAlign_ = Align(k_, sizeof(float));
    // init input gm buf
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que; the (sizeof(float) / sizeof(T)) factor leaves staging room for
    // the narrow-typed raw data when T is half/bf16.
    pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t)));
    pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    // Sorted buffers hold (value, index) pairs, hence the factor of 2.
    pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2);
    pipe_->InitBuffer(sortedGroupQueue_, 2,
                      (groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM *
                          sizeof(float) * 2);
    pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize);
}
// Load the bias once, then run the full top-k pipeline for every row owned by
// this core.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Process()
{
    CopyInBias();
    for (int64_t rowIdx = 0; rowIdx < curCoreRowCount_; ++rowIdx) {
        CopyInX(rowIdx);
        ComputeX();
        SortInGroup();
        SelectTopKGroupIndex();
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CopyOut(rowIdx);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H

View File

@@ -0,0 +1,669 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_generalized.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H
#define MOE_GATING_TOP_K_E_K_GENERALIZED_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
// Generalized implementation of MoeGatingTopK: supports softmax or sigmoid
// normalization (normType_), optional renorm, arbitrary group shapes, and an
// optional full normalized-score output (outGm_). One row per iteration.
// NOTE(review): "Generlized" looks like a typo of "Generalized", but the name
// is referenced by the kernel dispatch, so it cannot be renamed here alone.
template <typename T>
class MoeGatingTopKGenerlized {
public:
    __aicore__ inline MoeGatingTopKGenerlized(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();
private:
    // Per-row pipeline stages.
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SortInGroup();
    __aicore__ inline void SelectTopKGroupIndex();
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CumputeActualTopKExpertId();
    __aicore__ inline void CopyOut(int64_t row);
private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;
    TBuf<TPosition::VECCALC> biasBuf_; // Store input bias
    TBuf<TPosition::VECCALC> expertIdBuf_; // Expert ID
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // Store value after adding bias
    TBuf<TPosition::VECCALC> xNormBuf_; // Store value after computing sigmoid or softmax
    TBuf<TPosition::VECCALC> sortedInGroupBuf_; // Store sorted results within groups
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;
    TBuf<TPosition::VECCALC> sortedGroupIndexBuf_;
    TBuf<TPosition::VECCALC> calcTmpBuf_;
    // Global-memory views, offset per core in Init().
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> outGm_;
    // Cached tiling parameters.
    int64_t blockIdx_ = 0;
    int64_t perCoreRowCount_ = 0;
    int64_t curCoreRowCount_ = 0;
    int64_t expertCount_ = 0;
    bool addBias_ = false;
    int64_t k_ = 0;
    int64_t kGroup_ = 0;
    int64_t groupCount_ = 0;
    int64_t groupCountAlign_ = 0;
    int64_t perGroupExpertCount_ = 0;
    int64_t perGroupExpertCountAlign_ = 0;
    int64_t groupSelectMode_ = 0;
    int64_t renorm_ = 0;
    int64_t normType_ = 0;
    int64_t outFlag_ = 0;
    int64_t expertCountAlign_ = 0;
    int64_t kAlign_ = 0;
    bool isAlign_ = false;
    const MoeGatingTopKTilingData *tilingData_;
};
// Load the bias group-by-group into a group-aligned UB layout (widening to
// float when T is not float), pad the per-group tail with -inf so padding
// never wins a sort, and initialize the running expert-id ramp.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInBiasAndInitExpertId()
{
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    // Destination stride realigns each group to perGroupExpertCountAlign_.
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // Stage narrow-typed data after the float area, then widen to float.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }
        if (!isAlign_) {
            // Fill each group's alignment tail with -inf.
            int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
            int duplicateIndex = perGroupExpertCount_ - duplicateNum;
            if (duplicateNum > 0) {
                uint64_t mask0 = UINT64_MAX;
                mask0 = mask0 << duplicateNum;
                mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
                uint64_t mask[2] = {mask0, 0};
                Duplicate(biasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                          perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
            }
        }
    }
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCountAlign_);
}
// Load one row of expert scores group-by-group into the group-aligned UB
// layout; narrow-typed input is staged raw and widened later in ComputeX().
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    // Destination stride realigns each group to perGroupExpertCountAlign_.
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
// Normalize the row (sigmoid or numerically-stable softmax depending on
// normType_), optionally add the bias, and keep per-group alignment tails
// padded with -inf so they never win the subsequent sorts.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::ComputeX()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    if constexpr (!IsSameType<T, float>::value) {
        // Widen the staged narrow-typed data to float in place.
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = perGroupExpertCount_ - duplicateNum;
    if (!isAlign_ && duplicateNum > 0) {
        // Fill each group's alignment tail with -inf before normalizing.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xInLocalTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                  (perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES);
        PipeBarrier<PIPE_V>();
    }
    if (normType_ == 1) { // sigmoid
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    else if (normType_ == 0) { // softmax
        // Numerically stable softmax: subtract the row max, exponentiate,
        // divide by the sum.
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }
    if (!isAlign_ && duplicateNum > 0) {
        // Re-pad the tails after normalization/bias so sorting stays correct.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex],
                  FLOAT32_NEG_INF, // MIN_FP32,
                  mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopuOutXNorm(int64_t row)
{
    // Stage the normalized scores (sigmoid/softmax result) through the VECOUT
    // queue and write them to global memory for the optional `out` output.
    LocalTensor<float> stagedOut = outOutQueue_.AllocTensor<float>();
    LocalTensor<float> normScores = xNormBuf_.Get<float>();
    DataCopy(stagedOut, normScores, expertCountAlign_);
    outOutQueue_.EnQue<float>(stagedOut);
    stagedOut = outOutQueue_.DeQue<float>();
    // One burst per group; the srcStride skips the per-group alignment padding
    // so GM receives exactly perGroupExpertCount_ values per group.
    DataCopyExtParams copyOutParams{
        static_cast<uint16_t>(groupCount_), static_cast<uint32_t>(perGroupExpertCount_ * sizeof(float)),
        static_cast<uint32_t>((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0};
    DataCopyPad(outGm_[row * expertCount_], stagedOut, copyOutParams);
    outOutQueue_.FreeTensor(stagedOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SortInGroup()
{
    // Sort the biased scores of every expert group in descending order,
    // carrying the expert ids alongside. Result layout in sortedInGroupBuf_:
    // per group, perGroupExpertCountAlign_ interleaved (score, id) records.
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    LocalTensor<float> tmpLocal = calcTmpBuf_.Get<float>();
    if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) {
        // Fast path: each group is exactly one 32-element sort repeat, so a
        // single Sort32 call (one repeat per group) handles all groups.
        PipeBarrier<PIPE_V>();
        Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_);
    } else {
        // General path: run a full multi-repeat Sort over each group's span.
        for (int64_t group = 0; group < groupCount_; group++) {
            PipeBarrier<PIPE_V>();
            Sort<float, true>(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO],
                              xNormWithBiasTensor[group * perGroupExpertCountAlign_],
                              expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal,
                              perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM);
        }
    }
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKGroupIndex()
{
    // Score every group (either by its top-1 value or by the sum of its top-2
    // values), sort the groups by that score, and keep the indices of the best
    // kGroup_ groups, themselves sorted, in sortedGroupIndexBuf_.
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    // Scratch layout inside calcTmpBuf_ (float-sized slots, sequential offsets):
    LocalTensor<float> valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, 0);
    LocalTensor<uint32_t> maskTensor =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float));
    LocalTensor<float> topValueInGroupTensor =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float));
    LocalTensor<uint32_t> groupIndex =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float));
    LocalTensor<float> sortedTopValue =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float));
    LocalTensor<float> sortTmp =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float));
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering
    PipeBarrier<PIPE_V>();
    if (groupSelectMode_ == 1) { // group score = sum of its top-2 values
        // Extract the first two (score) elements of each group's sorted
        // (score, id) record list: bits 0 and 2 of each 32-bit mask word.
        maskTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
        maskTensor.SetValue(1, static_cast<uint32_t>(0));
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        // Each repeat starts at the next group's interleaved (score, id) span.
        gatherMaskParams.src0RepeatStride =
            Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
        PipeBarrier<PIPE_V>();
        // Pairwise-sum adjacent elements: the sum of the two largest values in
        // each group becomes that group's score.
        PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor,
                      Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1,
                      CONSTANT_EIGHT);
    } else { // group score = its top-1 value
        maskTensor.SetValue(0, static_cast<uint32_t>(1)); // b0001: keep only the first element per group
        maskTensor.SetValue(1, static_cast<uint32_t>(0));
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // NOTE(review): this declaration shadows the outer rsvdCnt; harmless,
        // since the value is not consumed afterwards.
        uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering
        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
    }
    PipeBarrier<PIPE_V>();
    // Generate group indices 0..groupCount_-1 to ride along through the sort.
    ArithProgression(groupIndex.ReinterpretCast<int32_t>(), static_cast<int32_t>(0), static_cast<int32_t>(1),
                     groupCount_);
    PipeBarrier<PIPE_V>();
    // Pad the tail of the score vector with -inf so alignment padding can never
    // be selected by the descending sort.
    int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = groupCount_ - duplicateNum;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(topValueInGroupTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1,
                  REPEAT_BLOCKS);
        PipeBarrier<PIPE_V>();
    }
    PipeBarrier<PIPE_V>();
    // Sort groups by score (descending), keeping group indices attached.
    Sort<float, true>(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32));
    PipeBarrier<PIPE_V>();
    // Extract the index halves of the sorted (score, index) records.
    uint8_t src1Pattern = 2; // Built-in fixed pattern
    GatherMask(groupIndex, sortedTopValue.template ReinterpretCast<uint32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0),
               {1, static_cast<uint8_t>(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt);
    PipeBarrier<PIPE_V>();
    // Again pad past the kGroup_ winners with -inf before the final sort.
    duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        duplicateIndex = kGroup_ - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(groupIndex.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS);
    }
    // Sort the selected group indices (reinterpreted as floats) in descending
    // order so later merge rounds can walk them from the tail.
    LocalTensor<float> sortedGroupIndex = sortedGroupIndexBuf_.Get<float>();
    PipeBarrier<PIPE_V>();
    Sort<float, true>(sortedGroupIndex, groupIndex.ReinterpretCast<float>(), groupIndex, sortTmp, Ceil(kGroup_, 32));
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertIdx()
{
    // Merge the per-group sorted (score, id) lists of the kGroup_ selected
    // groups into one globally sorted list, then gather the ids of the first
    // k_ records into topKExpertIdBuf_. The ids here are indices into the
    // padded layout; CumputeActualTopKExpertId converts them to expert ids.
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    LocalTensor<int32_t> sortedGroupIndex = sortedGroupIndexBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> mrgSort0Tensor = calcTmpBuf_.Get<float>();
    uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0};
    uint16_t lenArr[CONSTANT_FOUR] = {
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_),
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_)};
    MrgSort4Info params{lenArr, false, 0b1111, 1};
    MrgSortSrcList<float> srcList;
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Round 0: merge the selected groups up to four lists at a time. The *2
    // factors account for the interleaved (score, id) record layout; iterate
    // from the tail of sortedGroupIndex (it is sorted descending).
    for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) {
        int64_t mrgLen = Min(i + 1, CONSTANT_FOUR);
        if (mrgLen > 1) {
            if (mrgLen == MERGE_LIST_FOUR) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2;
            } else if (mrgLen == MERGE_LIST_THREE) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b111;
            } else {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = 0;
                offset[3] = 0;
                params.elementLengths[2] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b11;
            }
            srcList.src1 = sortedInGroupTensor[offset[0]];
            srcList.src2 = sortedInGroupTensor[offset[1]];
            srcList.src3 = sortedInGroupTensor[offset[2]];
            srcList.src4 = sortedInGroupTensor[offset[3]];
            PipeBarrier<PIPE_V>();
            MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params);
        } else {
            // Single remaining group: no merge needed, copy it through.
            offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
            PipeBarrier<PIPE_V>();
            DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]],
                     perGroupExpertCountAlign_ * 2);
        }
    }
    // Later rounds: keep 4-way merging runs of `baseLoop` groups, ping-ponging
    // between mrgSort0Tensor and sortedInGroupTensor as src/dst, vmsCount times.
    int32_t baseLoop = 4;
    LocalTensor<float> srcTensor = mrgSort0Tensor;
    LocalTensor<float> dstTensor = mrgSort0Tensor;
    for (int i = 0; i < tilingData_->vmsCount; i++) {
        if (i % 2 == 0) {
            srcTensor = mrgSort0Tensor;
            dstTensor = sortedInGroupTensor;
        } else {
            srcTensor = sortedInGroupTensor;
            dstTensor = mrgSort0Tensor;
        }
        int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR;
        int32_t quotient = kGroup_ / nextBaseRow;
        int32_t remainder = kGroup_ - quotient * nextBaseRow;
        if (quotient > 0) {
            // Full 4-way merges of equal-length runs.
            MrgSort4Info params;
            MrgSortSrcList<float> srcList;
            params.ifExhaustedSuspension = false;
            params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
            params.validBit = 0b1111;
            params.repeatTimes = 1;
            for (int j = 0; j < quotient; j++) {
                // *8 = 4 runs per merge * 2 floats per (score, id) record.
                srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j];
                srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)];
                srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)];
                srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)];
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params);
            }
        }
        if (remainder > 0) {
            // Tail: fewer than four full runs remain; the last run may be short.
            int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2;
            int32_t mrgLen = CeilDiv(remainder, baseLoop);
            int32_t tailRow = remainder - (mrgLen - 1) * baseLoop;
            if (mrgLen > 1) {
                MrgSort4Info params;
                MrgSortSrcList<float> srcList;
                params.repeatTimes = 1;
                params.ifExhaustedSuspension = false;
                params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
                srcList.src1 = srcTensor[baseOffset];
                srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2];
                if (mrgLen == MERGE_LIST_FOUR) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3];
                    params.elementLengths[3] = perGroupExpertCount_ * tailRow;
                    params.validBit = 0b1111;
                } else if (mrgLen == MERGE_LIST_THREE) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    params.elementLengths[2] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b111;
                } else {
                    params.elementLengths[1] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[2] = 0;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b11;
                }
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[baseOffset], srcList, params);
            } else {
                // Single short run: copy it through to keep src/dst consistent.
                PipeBarrier<PIPE_V>();
                DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2);
            }
        }
        baseLoop = nextBaseRow;
    }
    // Extract the id halves of the first k_ (score, id) records.
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering
    uint8_t src1Pattern = 2; // Built-in fixed pattern
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, dstTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertScore()
{
    // Gather the normalized scores of the selected top-k experts, optionally
    // renormalize them to sum to 1, and scale by routedScalingFactor.
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
    PipeBarrier<PIPE_V>();
    // Gather addresses by byte offset: convert element indices to bytes.
    Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
           k_);
    // Renormalize always in sigmoid mode, and in softmax mode only when the
    // renorm flag is set.
    bool needRenorm = (normType_ == 1 ) ||              // sigmoid
                      (normType_ == 0 && renorm_ == 1); // softmax + renorm
    if (needRenorm) {
        // maxValueTensor actually receives the ReduceSum result (the sum of the
        // k selected scores); the name is historical.
        LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
        // Offset of 32 floats reserves room for the reduce result ahead of the
        // scratch area.
        LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[32];
        PipeBarrier<PIPE_V>();
        ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        // eps guards against dividing by a zero (or tiny) sum.
        float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Duplicate(tmpTensor, sumValue, k_);
        PipeBarrier<PIPE_V>();
        Div(yOutTensor, yOutTensor, tmpTensor, k_);
    }
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
    if constexpr (!IsSameType<T, float>::value) {
        // Cast back to the input dtype in place before enqueueing.
        PipeBarrier<PIPE_V>();
        Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
    }
    yOutQueue_.EnQue<float>(yOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CumputeActualTopKExpertId()
{
    // NOTE(review): "Cumpute" is a typo for "Compute"; renaming would touch the
    // class declaration and callers, so it is kept as-is here.
    // topKExpertId holds indices in the padded per-group layout
    // (group * perGroupExpertCountAlign_ + offset). Convert to real expert ids
    // (group * perGroupExpertCount_ + offset) by subtracting the alignment
    // padding contributed by the preceding groups.
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> topKExpertIdFp32 = calcTmpBuf_.Get<float>();
    PipeBarrier<PIPE_V>();
    Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_);
    PipeBarrier<PIPE_V>();
    // group index = floor(paddedId / perGroupExpertCountAlign_), computed in fp32.
    Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_);
    PipeBarrier<PIPE_V>();
    Cast(expertIdxOut, topKExpertIdFp32, RoundMode::CAST_TRUNC, k_);
    PipeBarrier<PIPE_V>();
    // padding contributed per preceding group = align - count.
    Muls(expertIdxOut, expertIdxOut, static_cast<int32_t>(perGroupExpertCountAlign_ - perGroupExpertCount_), k_);
    PipeBarrier<PIPE_V>();
    Sub(expertIdxOut, topKExpertId, expertIdxOut, k_);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyOut(int64_t row)
{
    // Drain the output queues and write this row's top-k scores and expert ids
    // back to global memory.
    LocalTensor<T> scoresOut = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> idsOut = expertIdxOutQueue_.DeQue<int32_t>();
    DataCopyExtParams rowCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], scoresOut, rowCopyParams);
    // Reuse the same burst descriptor with the id element width.
    rowCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], idsOut, rowCopyParams);
    yOutQueue_.FreeTensor(scoresOut);
    expertIdxOutQueue_.FreeTensor(idsOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    // Cache tiling parameters, bind global-memory buffers to this core's row
    // range, and carve the unified buffer into queues/scratch via the TPipe.
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    groupSelectMode_ = tilingData_->groupSelectMode;
    expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float));
    kAlign_ = Align(k_, sizeof(float));
    // true when groups need no per-group alignment padding.
    isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_;
    // init input gm buf
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que
    // For non-fp32 T the staging buffer is doubled (factor sizeof(float)/sizeof(T)):
    // raw T data lands in the upper half and is cast to fp32 in the lower half.
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
    // sorted (score, id) records: float value + uint32 index per element.
    pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t)));
    pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO);
    pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t));
    // shared scratch sized for the largest intermediate (merge-sort ping-pong).
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Process()
{
    // Row-by-row pipeline: load logits, normalize (+ optional bias), optionally
    // dump the normalized scores, pick the top groups, merge-select the top-k
    // experts inside them, gather/renormalize their scores, and write out.
    CopyInBiasAndInitExpertId();
    for (int64_t row = 0; row < curCoreRowCount_; row++) {
        CopyInX(row);
        ComputeX();
        if (tilingData_->outFlag) {
            CopuOutXNorm(row);
        }
        SortInGroup();
        SelectTopKGroupIndex();
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CumputeActualTopKExpertId();
        CopyOut(row);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H

View File

@@ -0,0 +1,338 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_without_group.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
// Top-k gating kernel for the ungrouped case: each row of logits is normalized
// (softmax or sigmoid), optionally biased, fully sorted, and its top-k expert
// scores and ids are written out.
template <typename T>
class MoeGatingTopKWithoutGroup {
public:
    __aicore__ inline MoeGatingTopKWithoutGroup(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();
private:
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    // NOTE(review): "Copu" is a typo for "Copy"; kept to match the definition.
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CopyOut(int64_t row);
private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;           // staging for one row of logits
    TQue<QuePosition::VECOUT, 1> yOutQueue_;         // top-k scores out
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_; // top-k expert ids out
    TQue<QuePosition::VECOUT, 1> outOutQueue_;       // optional normalized-score dump
    TBuf<TPosition::VECCALC> biasBuf_;          // Store input bias
    TBuf<TPosition::VECCALC> expertIdBuf_;      // Expert ID
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // Store value after adding bias
    TBuf<TPosition::VECCALC> xNormBuf_;         // Store value after computing sigmoid or softmax
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;  // selected top-k expert indices
    TBuf<TPosition::VECCALC> calcTmpBuf_;       // shared scratch for sort/reduce
    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> outGm_;
    int64_t blockIdx_ = 0;         // this core's index
    int64_t perCoreRowCount_ = 0;  // rows per core (except possibly the last)
    int64_t curCoreRowCount_ = 0;  // rows handled by this core
    int64_t expertCount_ = 0;      // experts per row
    bool addBias_ = false;         // add bias before selection
    bool outFlag_ = false;         // dump normalized scores to `out`
    int64_t k_ = 0;                // number of experts to select
    int64_t renorm_ = 0;           // renorm flag (consulted in softmax mode)
    int64_t normType_ = 0;         // 0: softmax, 1: sigmoid
    int64_t expertCountAlign_ = 0; // expertCount_ rounded up to the sort granule
    const MoeGatingTopKTilingData *tilingData_;
};
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInBiasAndInitExpertId()
{
    // One-time setup: load the bias row (cast to fp32 if needed) and fill
    // expertIdBuf_ with the sequence 0..expertCount_-1 for the sort.
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // Land raw T data in the upper half of the buffer, then cast it
            // down to fp32 at the start once the copy has completed.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }
    }
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCount_);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInX(int64_t row)
{
    // Load one row of gating logits from global memory into the VECIN queue.
    LocalTensor<float> stagingTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams loadParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams padParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        // fp32 input lands directly at the start of the staging buffer.
        DataCopyPad(stagingTensor, xGm_[row * expertCount_], loadParams, padParams);
    } else {
        // Narrow dtypes land in the upper half; ComputeX casts them to fp32.
        DataCopyPad(stagingTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], loadParams,
                    padParams);
    }
    xInQueue_.EnQue(stagingTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::ComputeX()
{
    // Normalize the row of logits (sigmoid or numerically-stable softmax) into
    // xNormBuf_, add the optional bias into xNormWithBiasBuf_, and pad the tail
    // with -inf so alignment padding can never win the descending sort.
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    if constexpr (!IsSameType<T, float>::value) {
        // Cast the raw T data (staged in the upper half by CopyInX) to fp32.
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
        PipeBarrier<PIPE_V>();
    }
    if (normType_ == 1) { // sigmoid
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_);
        PipeBarrier<PIPE_V>();
    } else if (normType_ == 0) { // softmax (max-subtracted for numerical stability)
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        // Offset of 8 floats (one 32-byte block) reserved for the reduce result.
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[8];
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // exp(x - max) ...
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCount_);
        PipeBarrier<PIPE_V>();
        // ... divided by sum(exp(x - max)).
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_);
        PipeBarrier<PIPE_V>();
    }
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }
    // Overwrite the tail elements past expertCount_ with -inf.
    int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = expertCount_ - duplicateNum;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1);
        PipeBarrier<PIPE_V>();
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopuOutXNorm(int64_t row)
{
    // Stage the normalized scores through the VECOUT queue and write the first
    // expertCount_ values of this row to the optional `out` GM buffer.
    LocalTensor<float> stagedOut = outOutQueue_.AllocTensor<float>();
    LocalTensor<float> normScores = xNormBuf_.Get<float>();
    DataCopy(stagedOut, normScores, expertCountAlign_);
    outOutQueue_.EnQue<float>(stagedOut);
    stagedOut = outOutQueue_.DeQue<float>();
    DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(expertCount_ * sizeof(float)), 0, 0, 0};
    DataCopyPad(outGm_[row * expertCount_], stagedOut, copyOutParams);
    outOutQueue_.FreeTensor(stagedOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertIdx()
{
    // Sort all (score, id) pairs in descending score order, then gather the id
    // halves of the leading records into topKExpertId and stage them for output.
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> sortedScore = calcTmpBuf_.Get<float>();
    LocalTensor<float> sortTmp = calcTmpBuf_.Get<float>()[expertCountAlign_ * CONSTANT_TWO];
    PipeBarrier<PIPE_ALL>();
    Sort<float, true>(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp,
                      expertCountAlign_ / ONE_REPEAT_SORT_NUM);
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;
    uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering
    uint8_t src1Pattern = 2; // Built-in fixed pattern
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, sortedScore.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
    // NOTE(review): this copies expertCountAlign_ elements, but both
    // topKExpertIdBuf_ and the expertIdxOutQueue_ buffer are sized for
    // Align(k_, sizeof(float)) int32 elements in Init — confirm this cannot
    // overrun when expertCountAlign_ exceeds that size.
    DataCopy(expertIdxOut, topKExpertId, expertCountAlign_);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertScore()
{
    // Gather the normalized scores of the selected top-k experts, optionally
    // renormalize them to sum to 1, and scale by routedScalingFactor.
    // Fixes vs. original: `needRenorm` (a bool) was compared with `== 1`, and
    // the ReduceSum result was held in a local misleadingly named
    // `maxValueTensor`; behavior is unchanged.
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
    PipeBarrier<PIPE_V>();
    // Gather addresses by byte offset: convert element indices to bytes.
    Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
           k_);
    // Renormalize always in sigmoid mode; in softmax mode only when the renorm
    // flag is set.
    bool needRenorm = (normType_ == 1) ||               // sigmoid
                      (normType_ == 0 && renorm_ == 1); // softmax + renorm
    if (needRenorm) {
        LocalTensor<float> sumTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        PipeBarrier<PIPE_V>();
        ReduceSum(sumTensor, yOutTensor, tmpTensor, k_);
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        // eps guards against dividing by a zero (or tiny) sum.
        float sumValue = sumTensor.GetValue(0) + tilingData_->eps;
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Duplicate(tmpTensor, sumValue, k_);
        PipeBarrier<PIPE_V>();
        Div(yOutTensor, yOutTensor, tmpTensor, k_);
    }
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
    if constexpr (!IsSameType<T, float>::value) {
        // Cast back to the input dtype in place before enqueueing.
        PipeBarrier<PIPE_V>();
        Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
    }
    yOutQueue_.EnQue<float>(yOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyOut(int64_t row)
{
    // Drain the output queues and write this row's top-k scores and expert ids
    // back to global memory.
    LocalTensor<T> scoresOut = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> idsOut = expertIdxOutQueue_.DeQue<int32_t>();
    DataCopyExtParams rowCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], scoresOut, rowCopyParams);
    // Reuse the same burst descriptor with the id element width.
    rowCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], idsOut, rowCopyParams);
    yOutQueue_.FreeTensor(scoresOut);
    expertIdxOutQueue_.FreeTensor(idsOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                          GM_ADDR out, GM_ADDR workspace,
                                                          const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    // Cache tiling parameters, bind global-memory buffers to this core's row
    // range, and carve the unified buffer into queues/scratch via the TPipe.
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    outFlag_ = tilingData_->outFlag == 1;
    k_ = tilingData_->k;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    // init input gm buf
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    // init que
    // For non-fp32 T the staging buffer is doubled (factor sizeof(float)/sizeof(T)):
    // raw T data lands in the upper half and is cast to fp32 in the lower half.
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
    // init calc buf
    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float))
    ;
    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t));
    // init tmp buf
    // shared scratch sized for the full-row sort (records + sort workspace).
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Process()
{
    // Bias and the expert-id table are shared by every row: load them once.
    CopyInBiasAndInitExpertId();
    // Per-row pipeline: load scores -> normalize -> (optional full-score
    // output) -> select top-k indices and scores -> write results.
    for (int64_t row = 0; row < curCoreRowCount_; row++) {
        CopyInX(row);
        ComputeX();
        if (outFlag_) {
            // NOTE(review): "CopuOutXNorm" looks like a typo of
            // "CopyOutXNorm" — it matches the declared name, so left as-is.
            CopuOutXNorm(row);
        }
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CopyOut(row);
    }
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H

View File

@@ -0,0 +1,51 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling.h
* \brief
*/
#pragma once
#include <vector>
#include <graph/tensor.h>
#include "data_copy_transpose_tiling_def.h"
namespace optiling {
// Fills a CopyTransposeTiling from 4-D destination/source shapes.
// Dims are read positionally as [B, N, S, H]; both shapes are assumed to
// have at least 4 dims — no bounds check is done here (TODO confirm callers
// guarantee this).
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
                                       optiling::CopyTransposeTiling &tiling)
{
    // Positional indices of the B/N/S/H axes.
    constexpr int64_t B_INDEX = 0;
    constexpr int64_t N_INDEX = 1;
    constexpr int64_t S_INDEX = 2;
    constexpr int64_t H_INDEX = 3;
    std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
    std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
    tiling.set_dstShapeB(dstShapeInfo[B_INDEX]);
    tiling.set_dstShapeN(dstShapeInfo[N_INDEX]);
    tiling.set_dstShapeS(dstShapeInfo[S_INDEX]);
    tiling.set_dstShapeH(dstShapeInfo[H_INDEX]);
    // Per-head hidden size (integer division; dstShapeN must be non-zero).
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
    tiling.set_srcShapeB(srcShapeInfo[B_INDEX]);
    tiling.set_srcShapeN(srcShapeInfo[N_INDEX]);
    tiling.set_srcShapeS(srcShapeInfo[S_INDEX]);
    // NOTE(review): srcShapeHN is assigned the raw H dim (not H/N as on the
    // dst side) — confirm this asymmetry is intended.
    tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]);
    // Byte length of one source row of srcShapeHN elements.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    // Pre-multiplied dim products used by the kernel's address arithmetic.
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
} // namespace optiling

View File

@@ -0,0 +1,43 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling_def.h
* \brief
*/
#pragma once
#include <cstdint>
#include <register/tilingdata_base.h>
namespace optiling {
// Tiling fields consumed by the data-copy transpose kernel; populated by
// GetDataCopyTransposeTiling (data_copy_transpose_tiling.h).
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
// Destination [B, N, S, H] dims plus per-head size HN = H / N.
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
// Source dims (srcShapeHN holds the raw src H dim, see the helper).
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
// Byte length of one src row (srcShapeHN * typeSize).
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
// Pre-multiplied dim products for kernel address arithmetic.
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
// NOTE(review): not set by the visible helper — confirm the initializer.
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
// NOTE(review): not set by the visible helper — confirm the initializer.
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
} // namespace optiling

View File

@@ -0,0 +1,56 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <cstdio>
#include <string>
#include "toolchain/slog.h"
// Info and debug logs are compiled out in this build.
#define OP_LOGI(opname, ...)
#define OP_LOGD(opname, ...)
// BUG FIX: the original warn/error macros forwarded the message arguments to
// a single printf whose format string consumed only (opname), so the actual
// log message and its arguments were silently discarded. Print the prefix
// first, then the caller's message. The empty string literal "" is
// concatenated with the caller's (string-literal) format so the macros still
// compile when no message is supplied.
// NOTE(review): some call sites pass a std::string as opname; "%s" requires a
// C string — confirm those call sites use char* / .c_str().
#define OP_LOGW(opname, ...)                 \
    do {                                     \
        printf("[WARN][%s] ", (opname));     \
        printf("" __VA_ARGS__);              \
        printf("\n");                        \
    } while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...)  \
    do {                                     \
        printf("[ERRORx][%s] ", (opname));   \
        printf("" __VA_ARGS__);              \
        printf("\n");                        \
    } while (0)
#define OP_LOGE(opname, ...)                 \
    do {                                     \
        printf("[ERROR][%s] ", (opname));    \
        printf("" __VA_ARGS__);              \
        printf("\n");                        \
    } while (0)
namespace optiling {
// Tiling-specific error report: currently just the no-report error log.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...)   \
    do {                                                         \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// Runs `log_func` then `expr` (typically a return) when `cond` holds.
#define OP_CHECK_IF(cond, log_func, expr) \
    do {                                  \
        if (cond) {                       \
            log_func;                     \
            expr;                         \
        }                                 \
    } while (0)
// Null-pointer guard for tiling contexts: logs and fails the graph op.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr)                 \
    do {                                                         \
        if ((ptr) == nullptr) {                                  \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED;                             \
        }                                                        \
    } while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,256 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_base.h
* \brief
*/
#pragma once
#include <sstream>
#include <exe_graph/runtime/tiling_context.h>
#include <graph/utils/type_utils.h>
#include "tiling/platform/platform_ascendc.h"
#include "error_log.h"
#ifdef ASCENDC_OP_TEST
#define ASCENDC_EXTERN_C extern "C"
#else
#define ASCENDC_EXTERN_C
#endif
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Per-core hardware resource limits gathered from the platform.
struct AiCoreParams {
    uint64_t ubSize = 0;
    uint64_t blockDim = 0;
    uint64_t aicNum = 0;
    uint64_t l1Size = 0;
    uint64_t l0aSize = 0;
    uint64_t l0bSize = 0;
    uint64_t l0cSize = 0;
};
// Compile-time platform snapshot; socVersion stored as a raw int32_t
// (read by TilingRegistryNew::DoTilingImpl when platform info is absent).
struct CompileInfoCommon {
    uint32_t aivNum;
    uint32_t aicNum;
    uint64_t ubSize;
    uint64_t l1Size;
    uint64_t l0aSize;
    uint64_t l0bSize;
    uint64_t l0cSize;
    uint64_t l2CacheSize;
    int64_t coreNum;
    int32_t socVersion;
    uint32_t rsvd;
};
// Same data as CompileInfoCommon but with a typed soc-version enum
// (and no rsvd field).
struct FlashAttentionScoreGradCompileInfo {
    uint32_t aivNum;
    uint32_t aicNum;
    uint64_t ubSize;
    uint64_t l1Size;
    uint64_t l0aSize;
    uint64_t l0bSize;
    uint64_t l0cSize;
    uint64_t l2CacheSize;
    int64_t coreNum;
    platform_ascendc::SocVersion socVersion;
};
// NOTE(review): field-for-field identical to CompileInfoCommon — presumably
// kept separate for per-op compile-info registration; confirm before merging.
struct FACompileInfoCommon {
    uint32_t aivNum;
    uint32_t aicNum;
    uint64_t ubSize;
    uint64_t l1Size;
    uint64_t l0aSize;
    uint64_t l0bSize;
    uint64_t l0cSize;
    uint64_t l2CacheSize;
    int64_t coreNum;
    int32_t socVersion;
    uint32_t rsvd;
};
// Template-method base for op tiling: DoTiling() runs the fixed pipeline
// below; concrete tilings implement steps 1-7 (and may override step 8).
class TilingBaseClass {
public:
    explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
    {}
    virtual ~TilingBaseClass() = default;

    // Tiling execution framework
    // 1. GRAPH_SUCCESS: Success, and no need to continue executing subsequent Tiling class implementations
    // 2. GRAPH_FAILED: Failure, abort the entire Tiling process
    // 3. GRAPH_PARAM_INVALID: This class does not support, need to continue executing other Tiling class implementations
    ge::graphStatus DoTiling()
    {
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        // A template that rejects the current case lets the next one try.
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }

    // Update context
    virtual void Reset(gert::TilingContext* context)
    {
        context_ = context;
    }

protected:
    // Whether this template supports the current shapes/attrs.
    virtual bool IsCapable() = 0;
    // 1. Get platform information such as CoreNum, UB/L1/L0C resource sizes
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Get INPUT/OUTPUT/ATTR information
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Calculate data splitting TilingData
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Calculate high-level API TilingData
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Calculate TilingKey
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Calculate Workspace size
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Save Tiling data
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump Tiling data (only when debug-level logging is enabled)
    virtual void DumpTilingInfo()
    {
        int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        // Dump the raw tiling blob as a stream of uint32_t words.
        auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
                OP_LOGD(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OP_LOGD(context_, "%s", oss.str().c_str());
    }

    // Converts a vector-core slice count into a scheduler block dim using the
    // aiv/aic core ratio; degenerate ratios return sliceNum unchanged.
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration;
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration;
    }

    // Renders a shape as "[d0, d1, ...]" for logging.
    template <typename T>
    [[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }

    // Renders dtype/shape/format of one tensor for logging ("nil " if absent).
    [[nodiscard]] std::string GetTensorDebugStr(
        const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }

    // Renders every input/output tensor of the node for logging.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }

    // Renders the raw tiling blob as comma-separated int32 words.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }

protected:
    gert::TilingContext* context_ = nullptr;
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
    uint32_t blockDim_{0};
    uint64_t workspaceSize_{0};
    uint64_t tilingKey_{0};
    AiCoreParams aicoreParams_;
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -0,0 +1,63 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_key.h
* \brief
*/
#pragma once
#include <cstdint>
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Terminal case of the digit accumulator: an empty pack contributes 0.
constexpr uint64_t RecursiveSum()
{
    return 0;
}
constexpr uint64_t kBase = 10; // Base-10 carry base
// Folds the template ids into one decimal number, least-significant digit
// first: RecursiveSum(a, b, c) == a + 10*b + 100*c.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T headId, Args... tailIds)
{
    return kBase * RecursiveSum(tailIds...) + static_cast<uint64_t>(headId);
}
// TilingKey layout: FlashAttentionScore/FlashAttentionScoreGrad encode one
// decimal digit per parameter, low to high: Ub0, Ub1, Block, DataType,
// Format, Sparse.
//   Ub0/Ub1 — UB intra-core split axes (AxisEnum; AXIS_NONE when unsplit)
//   Block   — multi-core split axis (AxisEnum)
//   DataType / Format / Sparse — SupportedDtype / InputLayout /
//   SparseCapability enums respectively.
// Other specializations may define their own digit fields.
// usage:
//   uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1,
//       AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, SupportedDtype::FLOAT32,
//       InputLayout::BSH, SparseCapability::SUPPORT_ALL)
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Offsets the packed digits into the dedicated tiling-key range.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return RecursiveSum(templateIds...) + TILINGKEYOFFSET;
}
// usage: uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
                   SparseEnum::sparse))
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -0,0 +1,351 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_templates_registry.h
* \brief
*/
#pragma once
#include <map>
#include <string>
#include <memory>
#include "exe_graph/runtime/tiling_context.h"
#include "tiling_base.h"
#include "error_log.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Factory helper: instantiates tiling class T for the given context.
// Uses new(std::nothrow) so an allocation failure yields an empty pointer
// (callers test the result against nullptr before use).
template <typename T>
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
{
    return std::unique_ptr<T>(new (std::nothrow) T(context));
}
// Signature of the factory stored in the priority -> tiling-case maps.
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
// Holds the tiling-template factories registered for one op type, keyed by
// priority (smaller value = tried first by the registries).
class TilingCases {
public:
    explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
    {}

    // Registers tiling class T under the given priority.
    // A duplicate priority is rejected with an error log and ignored.
    template <typename T>
    void AddTiling(int32_t priority)
    {
        OP_CHECK_IF(
            cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
        cases_[priority] = TILING_CLASS<T>;
        OP_CHECK_IF(
            cases_[priority] == nullptr,
            OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
    }

    // Priority -> factory map (iteration order is ascending priority).
    const std::map<int32_t, TilingClassCase>& GetTilingCases()
    {
        return cases_;
    }

private:
    std::map<int32_t, TilingClassCase> cases_;
    const std::string op_type_;
};
// --------------------------------Interface with soc version --------------------------------
// Tiling registry keyed by (soc version, op type); each entry holds a
// priority-ordered set of tiling templates.
class TilingRegistryNew {
public:
    TilingRegistryNew() = default;

#ifdef ASCENDC_OP_TEST
    static TilingRegistryNew& GetInstance();
#else
    // Meyers singleton: one process-wide registry.
    static TilingRegistryNew& GetInstance()
    {
        static TilingRegistryNew registry_impl_;
        return registry_impl_;
    }
#endif

    // Returns (creating on demand) the tiling-case table for
    // (op_type, soc_version); nullptr when allocation failed.
    std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
    {
        auto soc_iter = registry_map_.find(soc_version);
        if (soc_iter == registry_map_.end()) {
            std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
            op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            registry_map_[soc_version] = op_type_map;
        } else {
            if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
                soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            }
        }
        OP_CHECK_IF(
            registry_map_[soc_version][op_type] == nullptr,
            OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
        return registry_map_[soc_version][op_type];
    }

    // Resolves the soc version (platform info, falling back to compile info),
    // then tries each registered template in ascending priority until one
    // returns something other than GRAPH_PARAM_INVALID.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context)
    {
        int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
        const char* op_type = context->GetNodeType();
        fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
        if (platformInfoPtr == nullptr) {
            // No platform info: read the soc version from compile info.
            auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
            OP_CHECK_IF(
                compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
            soc_version = compileInfoPtr->socVersion;
            OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
        } else {
            auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
            soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
            OP_LOGD(context, "soc version is %d", soc_version);
            if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
                OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
                return ge::GRAPH_FAILED;
            }
        }
        auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }

    // Same as above, but only the caller-listed priorities are tried, in the
    // given order; an unregistered priority is skipped silently.
    // NOTE(review): unlike the overload above, a RESERVED_VERSION result from
    // platform info is not rejected here — confirm this is intentional.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
    {
        int32_t soc_version;
        const char* op_type = context->GetNodeType();
        auto platformInfoPtr = context->GetPlatformInfo();
        if (platformInfoPtr == nullptr) {
            auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
            OP_CHECK_IF(
                compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
            soc_version = compileInfoPtr->socVersion;
            OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
        } else {
            auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
            soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
            OP_LOGD(context, "soc version is %d", soc_version);
        }
        auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
        for (auto priority_id : priorities) {
            auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
            if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
                auto templateFunc = tilingCaseIter->second(context);
                if (templateFunc != nullptr) {
                    ge::graphStatus status = templateFunc->DoTiling();
                    if (status == ge::GRAPH_SUCCESS) {
                        OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
                        return status;
                    }
                    OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
                }
            }
        }
        return ge::GRAPH_FAILED;
    }

    // Priority -> factory map for (op_type, soc_version); an empty map
    // (const ref to a member, safe lifetime) when nothing is registered.
    const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
    {
        auto soc_iter = registry_map_.find(soc_version);
        OP_CHECK_IF(
            soc_iter == registry_map_.end(),
            OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
            return empty_tiling_case_);
        auto op_iter = soc_iter->second.find(op_type);
        OP_CHECK_IF(
            op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
            return empty_tiling_case_);
        return op_iter->second->GetTilingCases();
    }

private:
    std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
    const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
};
// Fluent registration helper that binds a tiling class to an op type for
// one or more soc versions (used by the *_WITH_SOCVERSION macros below).
class RegisterNew {
public:
    explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
    {}

    // Registers T under `priority` for a single soc version.
    template <typename T>
    RegisterNew& tiling(int32_t priority, int32_t soc_version)
    {
        auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
        OP_CHECK_IF(
            tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
        tilingCases->AddTiling<T>(priority);
        return *this;
    }

    // Registers T under `priority` for every soc version in the list.
    template <typename T>
    RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
    {
        for (int32_t soc_version : soc_versions) {
            auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
            OP_CHECK_IF(
                tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
                return *this);
            tilingCases->AddTiling<T>(priority);
        }
        return *this;
    }

private:
    const std::string op_type_;
};
// --------------------------------Interface without soc version --------------------------------
// Priority-ordered tiling registry keyed only by op type (no soc version).
class TilingRegistry {
public:
    TilingRegistry() = default;

#ifdef ASCENDC_OP_TEST
    static TilingRegistry& GetInstance();
#else
    // Meyers singleton: one process-wide registry.
    static TilingRegistry& GetInstance()
    {
        static TilingRegistry registry_impl_;
        return registry_impl_;
    }
#endif

    // Returns (creating on first use) the tiling-case table for op_type;
    // nullptr when allocation failed.
    std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
    {
        if (registry_map_.find(op_type) == registry_map_.end()) {
            registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
        }
        OP_CHECK_IF(
            registry_map_[op_type] == nullptr,
            OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
        return registry_map_[op_type];
    }

    // Tries every registered template in ascending priority order; any status
    // other than GRAPH_PARAM_INVALID ends the search.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context)
    {
        const char* op_type = context->GetNodeType();
        // const ref: avoid copying the whole priority map on each call.
        const auto& tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }

    // Same as above, restricted to the caller-supplied priority list.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
    {
        const char* op_type = context->GetNodeType();
        const auto& tilingTemplateRegistryMap = GetTilingTemplates(op_type);
        for (auto priorityId : priorities) {
            // BUG FIX: the original indexed the map with operator[], which
            // default-inserts a null TilingClassCase for an unregistered
            // priority and then calls it (null function-pointer call). Look
            // up with find() and skip missing priorities, matching the
            // behavior of TilingRegistryNew's priorities overload.
            auto tilingCaseIter = tilingTemplateRegistryMap.find(priorityId);
            if (tilingCaseIter == tilingTemplateRegistryMap.end()) {
                continue;
            }
            auto templateFunc = tilingCaseIter->second(context);
            if (templateFunc != nullptr) {
                ge::graphStatus status = templateFunc->DoTiling();
                if (status == ge::GRAPH_SUCCESS) {
                    OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
                    return status;
                }
                if (status != ge::GRAPH_PARAM_INVALID) {
                    OP_LOGD(context, "Do op tiling failed");
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }

    // Priority -> factory map for op_type; an empty map (const ref to a
    // member, safe lifetime) when the op is not registered.
    const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
    {
        OP_CHECK_IF(
            registry_map_.find(op_type) == registry_map_.end(),
            OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
        return registry_map_[op_type]->GetTilingCases();
    }

private:
    std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
    const std::map<int32_t, TilingClassCase> empty_tiling_case_;
};
// Fluent registration helper for the soc-version-agnostic TilingRegistry
// (used by REGISTER_TILING_TEMPLATE / REGISTER_OPS_TILING_TEMPLATE below).
class Register {
public:
    explicit Register(std::string op_type) : op_type_(std::move(op_type))
    {}

    // Registers tiling class T under `priority` for this op type.
    template <typename T>
    Register& tiling(int32_t priority)
    {
        auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
        OP_CHECK_IF(
            tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
        tilingCases->AddTiling<T>(priority);
        return *this;
    }

private:
    const std::string op_type_;
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops
// op_type: operator name (bare token, stringized via #op_type), class_name: registered tiling class,
// soc_versions: chip version number(s)
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// NOTE(review): unlike the other variants, op_type is passed through to
// Register() without stringizing; also "VAR_UNUSED##op_type_##..." pastes the
// literal token "op_type_" (op_type_ is not a macro parameter), so the
// generated variable name does not include the op type — confirm two
// registrations sharing class_name and priority in one TU cannot collide.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
        Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
// op_type: operator name, class_name: registered tiling class
// soc_version: SOC version, used to distinguish different SOCs
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// Replaces REGISTER_TILING_TEMPLATE, if op_type is a string constant, remove the quotes
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::Register \
        __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
            Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)

View File

@@ -0,0 +1,139 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_type.h
* \brief
*/
#pragma once
#include <cstdint>
namespace optiling {
// Axes a tiling may split across UB/cores; NONE (9) means no split.
enum class AxisEnum {
    B = 0,
    N2 = 1,
    G = 2,
    S1 = 3,
    S2 = 4,
    D = 5,
    NONE = 9,
};
// Input/output data types a tiling key supports.
enum class DtypeEnum {
    FLOAT16 = 0,
    FLOAT32 = 1,
    BFLOAT16 = 2,
    FLOAT16_PRECISION = 3,
};
// Buffering strategy selector.
enum class PerformanceOrientedEnum {
    BIG_BUFFER = 1,
    BIG_DOUBLE_BUFFER = 2,
};
// Matmul high-level-API configuration variant.
enum class MatmulConfig {
    NULL_CONFIG = 0,
    NORMAL_CONFIG = 1,
    MDL_CONFIG = 2
};
// Whether a positional-shift-encoding (pse) input exists.
enum class PseConfig {
    NO_PSE = 0,
    EXIST_PSE = 1
};
// Whether an attention-mask input exists.
enum class AttenMaskConfig {
    NO_ATTEN_MASK = 0,
    EXIST_ATTEN_MASK = 1
};
// Whether a dropout mask input exists.
enum class DropOutConfig {
    NO_DROP_OUT = 0,
    EXIST_DROP_OUT = 1
};
// Cube (matmul) data format.
enum class CubeFormatEnum {
    ND = 0,
    NZ = 1
};
// Tensor layouts a tiling key supports.
enum class LayoutEnum {
    BSND = 0,
    SBND = 1,
    BNSD = 2,
    TND = 3,
    NTD_TND = 4
};
// Where cube inputs are sourced from.
enum class CubeInputSourceEnum {
    GM = 0,
    L1 = 1
};
// Generic on/off switch digit.
enum class OptionEnum {
    DISABLE = 0,
    ENABLE = 1
};
// Sparse attention patterns a tiling key supports.
enum class SparseEnum {
    ALL = 0,
    NONE = 1,
    ANY = 2,
    CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    BAND_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL = 7,
    RIGHT_DOWN_CAUSAL_BAND = 8,
    BAND_LEFT_UP_CAUSAL = 9
};
// Terminal case: an empty digit pack contributes nothing.
constexpr uint64_t RecursiveSum()
{
    return 0;
}
constexpr int64_t base10Multiplier = 10;
// Packs the ids into one decimal value, least-significant digit first:
// RecursiveSum(a, b, c) == a + 10*b + 100*c.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T headId, Args... restIds)
{
    return base10Multiplier * RecursiveSum(restIds...) + static_cast<uint64_t>(headId);
}
// TilingKey layout (low to high decimal digit): Ub0, Ub1, Block, DataType,
// Format, Sparse — one digit each; Ub0/Ub1 are the UB intra-core split axes
// (AxisEnum, AXIS_NONE when unsplit), Block the multi-core split axis, and
// DataType/Format/Sparse the dtype/layout/sparse capability enums.
// Other specializations may define their own digit fields.
// NOTE(review): this duplicates the helpers in tiling_key.h (namespace
// Ops::Transformer::OpTiling), and both headers define the TILINGKEY macro —
// confirm the two are never included in the same translation unit.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
// Offsets the packed digits into the dedicated tiling-key range.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return RecursiveSum(templateIds...) + TILINGKEYOFFSET;
}
// usage: uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
                   SparseEnum::sparse))
} // namespace optiling

View File

@@ -0,0 +1,30 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file tiling_util.h
 * \brief Shared tiling helpers: SoC-version classification and shape
 *        normalization used by op tiling implementations.
 */
#pragma once
#include "register/op_impl_registry.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Returns true when the SoC targeted by the given context is a "regbase" SoC
// version. NOTE(review): the classification criteria live in the .cpp, which
// is not visible here — consult the implementation before relying on details.
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
bool IsRegbaseSocVersion(const gert::TilingContext* context);
// Normalizes a shape so callers never see a scalar (rank-0) shape.
// NOTE(review): presumably returns a rank-1 {1} shape for scalars and the
// input reference otherwise — confirm against the implementation.
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -1219,10 +1219,65 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
}
// MoE gating top-k: selects the k highest-scoring experts per token by
// dispatching to the custom aclnnMoeGatingTopK NPU kernel.
//
// Args:
//   x:        router logits, 2-D (num_tokens, num_experts); float16 / float32 /
//             bfloat16.
//   k:        experts selected per token; must satisfy 0 < k <= num_experts.
//   k_group / group_count / group_select_mode:
//             grouped-expert selection controls, forwarded to the kernel as-is.
//   renorm / norm_type:
//             normalization controls forwarded to the kernel (per the Python
//             caller: norm_type 0 = softmax, 1 = sigmoid).
//   out_flag: whether the kernel should populate the third output.
//   routed_scaling_factor / eps: scaling factor and epsilon, forwarded as-is.
//   bias_opt: optional per-expert bias, 1-D (num_experts,), same dtype as x.
//
// Returns: (y, expert_idx, out) with shapes (num_tokens, k) in x's dtype,
// (num_tokens, k) int32, and (num_tokens, num_experts) float32.
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
    const at::Tensor& x,
    int64_t k,
    int64_t k_group,
    int64_t group_count,
    int64_t group_select_mode,
    int64_t renorm,
    int64_t norm_type,
    bool out_flag,
    double routed_scaling_factor,
    double eps,
    const c10::optional<at::Tensor>& bias_opt
)
{
    TORCH_CHECK(x.dim() == 2, "The x should be 2D");
    TORCH_CHECK(
        x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
        "float16, float32 or bfloat16 tensor expected but got a tensor with dtype: ",
        x.scalar_type());
    auto x_size = x.sizes();
    auto rows = x_size[0];
    auto expert_num = x_size[1];
    // k experts are picked out of expert_num, so it must be in (0, expert_num];
    // validating here gives a clear error instead of a kernel-side failure.
    TORCH_CHECK(k > 0 && k <= expert_num,
                "k should be in (0, ", expert_num, "] but got: ", k);
    const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
    if (bias.defined()) {
        TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
        TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
        auto bias_size = bias.sizes();
        TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
    }
    // Output buffers; shapes and dtypes follow the aclnn kernel contract.
    at::Tensor y = at::empty({rows, k}, x.options());
    at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
    at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
    EXEC_NPU_CMD(aclnnMoeGatingTopK,
                 x,
                 bias,
                 k,
                 k_group,
                 group_count,
                 group_select_mode,
                 renorm,
                 norm_type,
                 out_flag,
                 routed_scaling_factor,
                 eps,
                 y,
                 expert_idx,
                 out);
    return {y, expert_idx, out};
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
// vLLM-Ascend custom ops
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
@@ -1358,6 +1413,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"num_ranks) -> Tensor");
ops.impl("combine_prefill", torch::kPrivateUse1,
&vllm_ascend::combine_prefill);
ops.def(
"npu_moe_init_routing_custom(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, "
" int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, "
@@ -1365,4 +1421,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)"
);
ops.impl("npu_moe_init_routing_custom", torch::kPrivateUse1, &vllm_ascend::npu_moe_init_routing_custom);
// vLLM-Ascend custom ops
ops.def(
"moe_gating_top_k(Tensor x, "
"int k, "
"int k_group, "
"int group_count, "
"int group_select_mode, "
"int renorm, "
"int norm_type, "
"bool out_flag, "
"float routed_scaling_factor, "
"float eps,"
"Tensor? bias_opt=None)"
"-> (Tensor y ,Tensor expert_idx, Tensor out)"
);
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
}

View File

@@ -366,7 +366,43 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));
return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale};
}
// Meta (shape-inference) implementation of moe_gating_top_k: performs the same
// input validation and allocates outputs with the same shapes/dtypes as the
// NPU op, but launches no kernel. Registered under the Meta dispatch key so
// the op can be symbolically traced / graph-captured.
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k_meta(
    const at::Tensor& x,
    int64_t k,
    int64_t k_group,
    int64_t group_count,
    int64_t group_select_mode,
    int64_t renorm,
    int64_t norm_type,
    bool out_flag,
    double routed_scaling_factor,
    double eps,
    const c10::optional<at::Tensor>& bias_opt
)
{
    TORCH_CHECK(x.dim() == 2, "The x should be 2D");
    TORCH_CHECK(
        x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
        "float16, float32 or bfloat16 tensor expected but got a tensor with dtype: ",
        x.scalar_type());
    auto x_size = x.sizes();
    auto rows = x_size[0];
    auto expert_num = x_size[1];
    // Same k-range contract as the NPU implementation.
    TORCH_CHECK(k > 0 && k <= expert_num,
                "k should be in (0, ", expert_num, "] but got: ", k);
    const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
    if (bias.defined()) {
        TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
        TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
        auto bias_size = bias.sizes();
        TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
    }
    // Shapes/dtypes must stay in lockstep with moe_gating_top_k's outputs.
    at::Tensor y = at::empty({rows, k}, x.options());
    at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
    at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
    return {y, expert_idx, out};
}
} // namespace meta
} // namespace vllm_ascend
@@ -374,6 +410,7 @@ namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
@@ -402,5 +439,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
// moe_init_routing_custom
ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta);
// Moe_gating_top_k
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
}
}

View File

@@ -305,8 +305,8 @@ def test_select_experts(
)
call_moe_gatingtopk = check_npu_moe_gating_top_k(
hidden_states, topk, topk_group, num_expert_group, scoring_func,
custom_routing_function)
hidden_states, topk, renormalize, topk_group, num_expert_group,
scoring_func, custom_routing_function)
if not call_moe_gatingtopk and use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:

View File

@@ -0,0 +1,210 @@
import random

import numpy
import pytest
import torch

from vllm_ascend.utils import enable_custom_op

# Register the vllm-ascend custom ops (torch.ops._C_ascend.*) before any test
# body references them.
enable_custom_op()

# Tolerances for comparing NPU results against the NumPy reference.
RTOL_TOLERANCE = 1e-5
ATOL_TOLERANCE = 1e-8

# Fix random seed to ensure test reproducibility
seed = 45
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def softmax_func(x, axis=None):
    """Numerically stable softmax over ``axis`` (NumPy implementation).

    Half-precision inputs are promoted to float32 before the computation.

    Returns:
        tuple: (probabilities, per-axis max, per-axis sum of exponentials).
    """
    if "float16" in x.dtype.name:
        x = x.astype(numpy.float32)
    shifted_max = x.max(axis=axis, keepdims=True)
    exp_shifted = numpy.exp(x - shifted_max)
    exp_total = exp_shifted.sum(axis=axis, keepdims=True)
    return exp_shifted / exp_total, shifted_max, exp_total
def moe_gating_top_k_numpy_ref(x: torch.Tensor,
                               k: int,
                               bias: torch.Tensor | None,
                               k_group: int = 1,
                               group_count: int = 1,
                               group_select_mode: int = 0,
                               renorm: int = 0,
                               norm_type: int = 0,
                               y2_flag: bool = False,
                               routed_scaling_factor: float = 1.0,
                               eps: float = 1e-20) -> tuple:
    """NumPy reference implementation of MOE Gating TopK.

    For result comparison with the NPU operator; the op order here (stable
    sorts, masking before sorting, renorm after gather) must mirror the NPU
    kernel's semantics, so keep modifications minimal.

    Args:
        x: Input tensor of shape (num_tokens, num_experts)
        k: Number of top-k experts to select
        bias: Bias tensor of shape (num_experts,) (optional); added to the
            normalized scores for selection only, not to the returned weights
        k_group: Number of top-k groups to select
        group_count: Number of expert groups
        group_select_mode: Group selection mode (0: max, 1: top2 sum)
        renorm: Whether to renormalize the output (0/1)
        norm_type: Normalization type (0: softmax, 1: sigmoid)
        y2_flag: Whether to output original x as y2
        routed_scaling_factor: Scaling factor for routing weights
        eps: Small epsilon to avoid division by zero

    Returns:
        tuple: (y, indices, y2)
            - y: Top-k weights of shape (num_tokens, k), torch tensor in x's
              original dtype
            - indices: Top-k expert indices of shape (num_tokens, k), numpy
              int32 array
            - y2: Original normalized scores if y2_flag is True, else None
    """
    # Compute in float32 regardless of input dtype; cast back at the end.
    dtype = x.dtype
    if dtype != torch.float32:
        x = x.to(dtype=torch.float32)
        if bias is not None:
            bias = bias.to(dtype=torch.float32)
    x = x.numpy()
    if bias is not None:
        bias = bias.numpy()
    if norm_type == 0:  # softmax normalization
        x, _, _ = softmax_func(x, -1)
    else:  # sigmoid normalization
        x = 1 / (1 + numpy.exp(-x))
    # Keep the bias-free scores: gathered weights come from these, while the
    # biased scores are used only for ranking.
    original_x = x
    if bias is not None:
        x = x + bias
    if group_count > 1:
        # (num_tokens, group_count, experts_per_group)
        x = x.reshape(x.shape[0], group_count, -1)
        if group_select_mode == 0:
            # Score each group by its single best expert.
            group_x = numpy.amax(x, axis=-1)
        else:
            # Score each group by the sum of its two best experts.
            group_x = numpy.partition(x, -2, axis=-1)[..., -2:].sum(axis=-1)
        # Stable argsort keeps deterministic tie-breaking across runs.
        indices = numpy.argsort(-group_x, axis=-1, kind='stable')[:, :k_group]
        # Mask out every expert in the non-selected groups with -inf so they
        # can never win the final top-k sort below.
        mask = numpy.ones((x.shape[0], group_count), dtype=bool)
        mask[numpy.arange(x.shape[0])[:, None], indices] = False
        x = numpy.where(mask[..., None], float('-inf'), x)
        x = x.reshape(x.shape[0], -1)
    # torch.sort with stable=True matches the kernel's tie-breaking; numpy's
    # descending sort has no direct stable equivalent.
    _, indices = torch.sort(torch.from_numpy(x),
                            dim=-1,
                            stable=True,
                            descending=True)
    indices = numpy.asarray(indices[:, :k])
    # Gather weights from the bias-free scores at the selected indices.
    y = numpy.take_along_axis(original_x, indices, axis=1)
    if norm_type == 1 or renorm == 1:
        # Renormalize so the selected weights sum to 1 (eps guards /0).
        y /= (numpy.sum(y, axis=-1, keepdims=True) + eps)
    y *= routed_scaling_factor
    y2 = original_x if y2_flag else None
    y = torch.tensor(y, dtype=dtype)
    return y, indices.astype(numpy.int32), y2
# pytest parameterized decorators (cover all test scenarios)
@pytest.mark.parametrize("group_select_mode", [0, 1])
@pytest.mark.parametrize("renorm", [1])
@pytest.mark.parametrize("norm_type", [0, 1])
@pytest.mark.parametrize("group_count", [1, 8])
@pytest.mark.parametrize("k_ranges", [4, 8, 12, 16, 6, 32])
@pytest.mark.parametrize("x_dim0_range", list(range(1, 17)))
@pytest.mark.parametrize("x_dim1_range", [256, 128, 64, 208, 192, 160])
def test_npu_moe_gating_topk_compare(group_select_mode: int,
                                     renorm: int,
                                     norm_type: int,
                                     group_count: int,
                                     k_ranges: int,
                                     x_dim0_range: int,
                                     x_dim1_range: int,
                                     device: str = "npu"):
    """Ascend NPU MOE Gating TopK operator test.

    Compare NPU kernel results with NumPy reference implementation
    to verify the correctness of Ascend custom op. Requires Ascend NPU
    hardware (tensors are moved with ``.npu()``).

    Args:
        group_select_mode: Group selection mode (0: max, 1: top2 sum)
        renorm: Whether to renormalize output (fixed to 1 in test)
        norm_type: Normalization type (0: softmax, 1: sigmoid)
        group_count: Number of expert groups
        k_ranges: Number of top-k experts to select
        x_dim0_range: First dimension of input tensor (num_tokens)
        x_dim1_range: Second dimension of input tensor (num_experts)
        device: Target device (fixed to "npu" in test)
    """
    # Simplify parameter names for better readability
    k = k_ranges
    dim0 = x_dim0_range
    dim1 = x_dim1_range

    # Skip invalid cases: k cannot exceed num_experts per group
    # NOTE(review): a plain return silently passes; pytest.skip() would make
    # the skipped combinations visible in the report.
    if k > dim1 // group_count:
        return

    # Construct test inputs
    x = numpy.random.uniform(-2, 2, (dim0, dim1)).astype(numpy.float32)
    bias = numpy.random.uniform(-2, 2, (dim1, )).astype(numpy.float32)
    x_tensor = torch.tensor(x, dtype=torch.float32)
    bias_tensor = torch.tensor(bias, dtype=torch.float32)

    # Fix k_group value to avoid irreproducibility caused by random.randint
    # NOTE(review): min(1, group_count) is always 1 for group_count >= 1, so
    # only k_group=1 is ever exercised here.
    k_group = min(1, group_count)
    out_flag = False
    routed_scaling_factor = 1.0
    eps = 1e-20

    # Calculate NumPy reference results ('out' is None because out_flag=False)
    y, expert_idx, out = moe_gating_top_k_numpy_ref(
        x_tensor,
        k=k,
        bias=bias_tensor,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=renorm,
        norm_type=norm_type,
        y2_flag=out_flag,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
    )

    # Calculate NPU operator results (out_npu is unused for out_flag=False)
    y_npu, expert_idx_npu, out_npu = torch.ops._C_ascend.moe_gating_top_k(
        x_tensor.npu(),
        k=k,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=renorm,
        norm_type=norm_type,
        out_flag=out_flag,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
        # NOTE(review): bias_tensor is constructed above, so the condition is
        # always true; kept for symmetry with optional-bias call sites.
        bias_opt=bias_tensor.npu() if bias_tensor is not None else None,
    )

    # Verify consistency between NPU and NumPy results
    assert numpy.allclose(y.cpu().numpy(),
                          y_npu.cpu().numpy(),
                          rtol=RTOL_TOLERANCE,
                          atol=ATOL_TOLERANCE)
    assert numpy.allclose(expert_idx,
                          expert_idx_npu.cpu().numpy(),
                          rtol=RTOL_TOLERANCE,
                          atol=ATOL_TOLERANCE)


if __name__ == "__main__":
    # Execute pytest tests with verbose output
    pytest.main([__file__, "-sv"])

View File

@@ -17,7 +17,6 @@
from typing import Callable, Optional
import torch
import torch_npu
from vllm_ascend.utils import get_weight_prefetch_method
@@ -64,6 +63,7 @@ def select_experts(hidden_states: torch.Tensor,
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
@@ -102,10 +102,13 @@ def select_experts(hidden_states: torch.Tensor,
def check_npu_moe_gating_top_k(
hidden_states: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
custom_routing_function: Optional[Callable] = None):
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch
return False
if custom_routing_function is not None:
return False
if scoring_func != "softmax" and scoring_func != "sigmoid":
@@ -209,26 +212,25 @@ def _select_experts_with_fusion_ops(
topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1
renorm = int(renormalize)
norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and \
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
router_logits,
k=top_k,
bias=e_score_correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
group_select_mode=1,
renorm=renorm,
norm_type=norm_type, # 0: softmax; 1: sigmoid
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
out_flag=False,
routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20))
if scoring_func == "softmax":
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
eps=float(1e-20),
bias_opt=e_score_correction_bias,
)
return topk_weights, topk_ids