Revert "moe_gating_top_k" (#5512)

Reverts vllm-project/vllm-ascend#5271

It breaks e2e test

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
This commit is contained in:
zzzzwwjj
2025-12-30 15:05:47 +08:00
committed by GitHub
parent 4ff4d1cef9
commit 71f729a661
34 changed files with 22 additions and 4791 deletions

View File

@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
@@ -70,7 +70,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
"dispatch_layout"
"notify_dispatch"
"moe_init_routing_custom"
"moe_gating_top_k"
)
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
SOC_ARG="ascend910_93"

View File

@@ -1,43 +0,0 @@
# ----------------------------------------------------------------------------
# This program is free software, you can redistribute it and/or modify.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------
add_ops_compile_options(
OP_NAME MoeGatingTopK
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
# Host 侧算子实现aclnn
if (BUILD_OPEN_PROJECT)
target_sources(op_host_aclnn PRIVATE
moe_gating_top_k_def.cpp
)
# Tiling 模块
target_sources(optiling PRIVATE
moe_gating_top_k_tiling.cpp
moe_gating_top_k_tiling_base.cpp
moe_gating_top_k_tiling_arch35.cpp
)
target_sources(opsproto PRIVATE
moe_gating_top_k_proto.cpp
moe_gating_top_k_infershape.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
endif()

View File

@@ -1,56 +0,0 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -1,61 +0,0 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file math_util.h
* \brief
*/
#ifndef TILING_MATMUL_MATH_UTIL_H
#define TILING_MATMUL_MATH_UTIL_H
#include <array>
#include <cstdint>
#include <vector>
#include <utility>
namespace matmul_tiling {
class MathUtil {
public:
static bool IsEqual(float leftValue, float rightValue);
template<typename T>
static auto CeilDivision(T num1, T num2) -> T
{
if (num2 == 0) {
return 0;
}
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
static_cast<int64_t>(num2));
}
template<typename T>
static auto Align(T num1, T num2) -> T
{
return CeilDivision(num1, num2) * num2;
}
static int32_t AlignDown(int32_t num1, int32_t num2);
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
const int32_t factorEnd);
static bool CheckFactorNumSatisfy(const int32_t dim);
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
bool isKDim);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
const int32_t coreNum, const int32_t maxNum);
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
};
} // namespace matmul_tiling
#endif // _MATH_UTIL_H_

View File

@@ -1,71 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
class MoeGatingTopK : public OpDef {
public:
explicit MoeGatingTopK(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("bias")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Output("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("expert_idx")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("k").Int();
this->Attr("k_group").AttrType(OPTIONAL).Int(1);
this->Attr("group_count").AttrType(OPTIONAL).Int(1);
this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0);
this->Attr("renorm").AttrType(OPTIONAL).Int(0);
this->Attr("norm_type").AttrType(OPTIONAL).Int(0);
this->Attr("out_flag").AttrType(OPTIONAL).Bool(false);
this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0);
this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
OpAICoreConfig regbaseCfg;
regbaseCfg.DynamicCompileStaticFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt");
this->AICore().AddConfig("ascend910_95", regbaseCfg);
}
};
OP_ADD(MoeGatingTopK);
} // namespace ops

View File

@@ -1,147 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_infershape.cpp
* \brief
*/
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "error_log.h"
#include <string>
#include <string>
#define TO_STRING(x) std::string(#x)
using namespace ge;
namespace ops {
static constexpr size_t DIM_ONE = 1;
static constexpr size_t DIM_TWO = 2;
static constexpr int64_t NEG_ONE = -1;
static constexpr int64_t X_INDEX = 0;
static constexpr int64_t BIAS_INDEX = 1;
static constexpr int64_t Y_INDEX = 0;
static constexpr int64_t EXPERT_IDX_INDEX = 1;
static constexpr int64_t OUT_INDEX = 2;
static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape)
{
int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
if (XRows < NEG_ONE || expertNum < NEG_ONE) {
OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape,
const int64_t k)
{
if (xShape->GetDimNum() == 1U) {
if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) {
OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.",
TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
} else if (xShape->GetDimNum() != DIM_TWO) {
OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.",
TO_STRING(*xShape).c_str());
return ge::GRAPH_FAILED;
}
if (k < 0) {
OP_LOGE(context, "k must be a non-negative number.");
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k)
{
OP_LOGD(context, "x shape is: %s.", TO_STRING(*xShape).c_str());
OP_LOGD(context, "k is: %ld.", k);
}
static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape,
const gert::Shape *expertIdxShape, const gert::Shape *outShape)
{
OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str());
OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str());
OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str());
}
static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context)
{
OP_LOGD(context, "Begin to do MoeGatingTopKInfershape.");
// 获取输入shape
const gert::Shape *xShape = context->GetInputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
gert::Shape *yShape = context->GetOutputShape(0);
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
gert::Shape *expertIdxShape = context->GetOutputShape(1);
OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape);
gert::Shape *outShape = context->GetOutputShape(2);
OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
// 获取attr
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(0);
OP_CHECK_NULL_WITH_CONTEXT(context, kPtr);
const int64_t k = *kPtr;
ShowInputShapeInfo(context, xShape, k);
// 参数校验
if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
yShape->SetDimNum(DIM_TWO);
yShape->SetDim(0U, rows);
yShape->SetDim(1U, k);
expertIdxShape->SetDimNum(DIM_TWO);
expertIdxShape->SetDim(0U, rows);
expertIdxShape->SetDim(1U, k);
outShape->SetDimNum(DIM_TWO);
outShape->SetDim(0U, rows);
outShape->SetDim(1U, expertNum);
ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape);
OP_LOGD(context, "End to do MoeGatingTopKInfershape.");
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context)
{
OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType.");
auto xDtype = context->GetInputDataType(0);
context->SetOutputDataType(Y_INDEX, xDtype);
context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32);
context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT);
OP_LOGD(context, "End to do MoeGatingTopKInferDataType.");
return ge::GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK);
} // namespace ops

View File

@@ -1,15 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.h
* \brief
*/
#include "moe_gating_top_k_proto.h"

View File

@@ -1,66 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_proto.h
* \brief
*/
#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
#include "graph/operator_reg.h"
namespace ge {
/**
* @brief Compute renorm(sigmoid) and topk for moe input.
*
* @par Inputs:
* @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count.
* @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x.
*
* @par Outputs:
* @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x.
The size of the non-1 axis must be the same as that of the corresponding axis of x.
The size of the -1 axis must be the same as that of k.
* @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y.
* @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x.
*
* @par Attributes:
* @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value.
* @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1.
* @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1.
* @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0.
* @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0.
* @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0.
* @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false.
* @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0.
* @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20.
*/
REG_OP(MoeGatingTopK)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
.OUTPUT(expert_idx, TensorType({DT_INT32}))
.OUTPUT(out, TensorType({DT_FLOAT}))
.REQUIRED_ATTR(k, Int)
.ATTR(k_group, Int, 1)
.ATTR(group_count, Int, 1)
.ATTR(group_select_mode, Int, 0)
.ATTR(renorm, Int, 0)
.ATTR(norm_type, Int, 0)
.ATTR(out_flag, Bool, false)
.ATTR(routed_scaling_factor, Float, 1.0)
.ATTR(eps, Float, 1e-20f)
.OP_END_FACTORY_REG(MoeGatingTopK)
} // namespace ge
#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_

View File

@@ -1,580 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling.cpp
* \brief
*/
#include <cmath>
#include "register/op_def_registry.h"
#include "exe_graph/runtime/infer_shape_context.h"
#include "register/op_impl_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "platform/platform_info.h"
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
// 放在文件顶部,或单独头文件中
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
const static int64_t MAX_EXPERT_COUNT = 2048;
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216; // 预留16M空间
const static uint32_t DATATYPESIZE_FLOAT = 4;
const static bool IS_LARGEST = true;
const static bool IS_INITINDEX = false;
const static bool IS_REUSESOURCE = false;
const static uint64_t WITH_GROUP_CONDITION = 1;
const static uint64_t WITHOUT_GROUP_CONDITION = 2;
const static uint64_t MAX_IN_GROUP_CONDITION = 3;
constexpr int32_t ROW_COUNT_PER_TASK = 1;
const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0;
const static uint64_t TILING_KEY_WITHOUT_GROUP = 1;
const static uint64_t TILING_KEY_GENERALIZED = 2;
inline static int64_t CeilLog4(int64_t x)
{
return static_cast<int64_t>(std::ceil(std::log(x) / std::log(4))); // 4 for four
}
class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
{
Reset();
}
~MoeGatingTopKTilingBase() override = default;
void Reset(gert::TilingContext *context) override
{
TilingBaseClass::Reset(context);
Reset();
}
protected:
bool IsCapable() override
{
return true;
}
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
ge::graphStatus GetPlatformInfo() override;
// 2、获取INPUT/OUTPUT/ATTR信息
ge::graphStatus GetShapeAttrsInfo() override;
// 3、计算数据切分TilingData
ge::graphStatus DoOpTiling() override;
// 4、计算高阶API的TilingData
ge::graphStatus DoLibApiTiling() override;
// 5、计算TilingKey
uint64_t GetTilingKey() const override;
// 6、计算Workspace 大小
ge::graphStatus GetWorkspaceSize() override;
// 7、保存Tiling数据
ge::graphStatus PostTiling() override;
void Reset();
private:
ge::graphStatus CheckInputShape();
ge::graphStatus CheckAttr();
ge::graphStatus CheckOutShape();
void SplitRows();
void CalTmpBufUbSize();
const gert::Shape *xShape_ = nullptr;
const gert::Shape *biasShape_ = nullptr;
const gert::Shape *yShape_ = nullptr;
const gert::Shape *expertIdxShape_ = nullptr;
const gert::Shape *outShape_ = nullptr;
int64_t rows_ = 0;
int64_t expertCount_ = 0;
int64_t addBias_ = 0;
int64_t k_ = 0;
int64_t kGroup_ = 0;
int64_t groupCount_ = 0;
int64_t perGroupExpertCount_ = 0;
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
int64_t renorm_ = RENORM_NO;
int64_t normType_ = NORM_TYPE_SOFTMAX;
int64_t outFlag_ = OUT_FLAG_FALSE;
float routedScalingFactor_ = 1.0;
float eps_ = 1e-20f;
int64_t inputDtypeSize_;
const char *opName_ = "";
MoeGatingTopKTilingData moeGatingTopKTilingData_;
};
ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
// 通过输入获取rows 和 expertCount
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::CheckAttr()
{
OP_CHECK_IF(
expertCount_ > MAX_EXPERT_COUNT,
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ > groupCount_,
OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
return ge::GRAPH_FAILED);
OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
return ge::GRAPH_FAILED);
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
OP_LOGE(context_, "Expert count : %ld is not divisible by k_group: %ld", expertCount_, groupCount_),
return ge::GRAPH_FAILED);
perGroupExpertCount_ = expertCount_ / groupCount_;
OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
OP_CHECK_IF(perGroupExpertCount_ < 1,
OP_LOGE(context_, "group expert count is: %ld, but should be greater than 1.", perGroupExpertCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
OP_LOGE(context_,
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
perGroupExpertCount_, groupSelectMode_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_),
return ge::GRAPH_FAILED);
int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L);
OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign);
if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) {
// 分组场景下才需要校验对齐后的数量
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_);
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
// 获取输入shape信息
OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_);
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str());
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
if (biasShape_ != nullptr) {
OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str());
}
// 获取输出shape
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str());
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str());
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr);
outShape_ = &outPtr->GetStorageShape();
if (outShape_ != nullptr) {
OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str());
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str());
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str());
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str());
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc);
auto normOutDtype = normOutDesc->GetDataType();
OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.",
ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()),
return ge::GRAPH_FAILED);
// 获取属性
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
OP_LOGI(context_, "Attr k is: %ld", k_);
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
OP_LOGI(context_, "Attr k_group is: %ld", kGroup_);
moeGatingTopKTilingData_.set_kGroup(kGroup_);
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
OP_LOGI(context_, "Attr group_count is: %ld", groupCount_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
}
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_);
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
}
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
OP_LOGI(context_, "Attr renorm is: %ld", renorm_);
moeGatingTopKTilingData_.set_renorm(renorm_);
}
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
OP_LOGI(context_, "Attr norm_type is: %ld", normType_);
moeGatingTopKTilingData_.set_normType(normType_);
}
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_);
moeGatingTopKTilingData_.set_outFlag(outFlag_);
}
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_);
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
}
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
OP_LOGI(context_, "Attr eps is: %f", eps_);
moeGatingTopKTilingData_.set_eps(eps_);
}
OP_LOGI(context_, "Attr eps is: %f ", eps_);
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo()
{
auto platformInfo = context_->GetPlatformInfo();
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSizePlatForm;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
aicoreParams_.ubSize = ubSizePlatForm;
OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape()
{
OP_LOGI(context_, "555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
if (outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
outShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
outShape_->GetDim(0), outShape_->GetDim(0)),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(1) != k_),
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
xShape_->GetDim(1)),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
void MoeGatingTopKTilingBase::SplitRows()
{
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
// perCoreRows cannot be 0
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L));
OP_LOGI(context_, "vms count is: %ld", vmsCount);
moeGatingTopKTilingData_.set_vmsCount(vmsCount); // 需要归并的轮数
}
void MoeGatingTopKTilingBase::CalTmpBufUbSize()
{
std::vector<int64_t> shape_vec = {expertCount_};
ge::Shape shape(shape_vec);
uint32_t maxValue = 0;
uint32_t minValue = 0;
AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue);
int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast<int64_t>(sizeof(float));
moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast<int64_t>(minValue)));
}
ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling()
{
OP_LOGI(context_, "DoOpTiling: start");
auto ret = CheckInputShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckOutShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckAttr();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
CalTmpBufUbSize();
SplitRows();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize()
{
// 计算workspace大小
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingBase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
uint64_t MoeGatingTopKTilingBase::GetTilingKey() const
{
// DeepSeekV3排序对齐高性能场景
if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ &&
groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID &&
!outFlag_) {
// DeepSeekV3排序对齐高性能场景
return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF;
} else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) {
/**
* 不分组场景:
* 1. 分组数为 1
* 2. 分组数等于专家数(每个组只有一个专家)
* 3. 选择所有组
*/
return TILING_KEY_WITHOUT_GROUP;
} else {
return TILING_KEY_GENERALIZED;
}
}
void MoeGatingTopKTilingBase::Reset()
{
opName_ = nullptr;
return;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000);
} // namespace optiling

View File

@@ -1,86 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_tiling.h
* \brief
*/
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
#include <cmath>
#include <cstdint>
#include <vector>
#include <algorithm>
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "register/op_def_registry.h"
#include "register/op_impl_registry.h"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "platform/platform_infos_def.h"
#include "math_util.h"
//#include "util/extern_math_util.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData)
BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
TILING_DATA_FIELD_DEF(int64_t, rowCount);
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
TILING_DATA_FIELD_DEF(int64_t, expertCount);
TILING_DATA_FIELD_DEF(int64_t, addBias);
TILING_DATA_FIELD_DEF(int64_t, k);
TILING_DATA_FIELD_DEF(int64_t, kGroup);
TILING_DATA_FIELD_DEF(int64_t, groupCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
TILING_DATA_FIELD_DEF(int64_t, renorm);
TILING_DATA_FIELD_DEF(int64_t, normType);
TILING_DATA_FIELD_DEF(int64_t, outFlag);
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
TILING_DATA_FIELD_DEF(float, eps);
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData)
struct MoeGatingTopKCompileInfo {};
} // namespace optiling
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H

View File

@@ -1,521 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling_arch35.cpp
* \brief
*/
#include "error_log.h"
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "platform/platform_info.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
namespace optiling {
const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000;
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
const static int64_t MAX_EXPERT_COUNT = 2048;
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
const static int64_t MRGSORT_SIZE = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast<int64_t>(16 * 1024 * 1024); // 预留16M空间
class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
{
Reset();
}
~MoeGatingTopKTilingRegbase() override = default;
void Reset(gert::TilingContext *context) override
{
TilingBaseClass::Reset(context);
Reset();
}
protected:
bool IsCapable() override
{
if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) {
return false;
}
return true;
}
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
ge::graphStatus GetPlatformInfo() override;
// 2、获取INPUT/OUTPUT/ATTR信息
ge::graphStatus GetShapeAttrsInfo() override;
// 3、计算数据切分TilingData
ge::graphStatus DoOpTiling() override;
// 4、计算高阶API的TilingData
ge::graphStatus DoLibApiTiling() override;
// 5、计算TilingKey
uint64_t GetTilingKey() const override;
// 6、计算Workspace 大小
ge::graphStatus GetWorkspaceSize() override;
// 7、保存Tiling数据
ge::graphStatus PostTiling() override;
void Reset();
private:
ge::graphStatus CheckInputShape();
ge::graphStatus CheckAttr();
ge::graphStatus CheckOutShape();
void CalTmpBufUbSize();
void SplitRows();
void Tiling4GatherOutComputeSplitK();
const gert::Shape *xShape_ = nullptr;
const gert::Shape *biasShape_ = nullptr;
const gert::Shape *yShape_ = nullptr;
const gert::Shape *expertIdxShape_ = nullptr;
const gert::Shape *outShape_ = nullptr;
int64_t rows_;
int64_t expertCount_;
int64_t addBias_ = 0;
int64_t k_;
int64_t kGroup_ = 1;
int64_t groupCount_ = 1;
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
int64_t renorm_ = RENORM_NO;
int64_t normType_ = NORM_TYPE_SOFTMAX;
int64_t outFlag_ = OUT_FLAG_FALSE;
float routedScalingFactor_ = 1.0;
float eps_ = 1e-20f;
int64_t inputDtypeSize_;
const char *opName_ = "";
MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_;
platform_ascendc::SocVersion socVersion;
};
ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape()
{
size_t xDimNum = xShape_->GetDimNum();
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
return ge::GRAPH_FAILED);
// 通过输入获取rows 和 expertCount
rows_ = xShape_->GetDim(0);
expertCount_ = xShape_->GetDim(1);
moeGatingTopKTilingData_.set_rowCount(rows_);
moeGatingTopKTilingData_.set_expertCount(expertCount_);
OP_CHECK_IF(
expertCount_ > MAX_EXPERT_COUNT,
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
if (biasShape_ != nullptr) {
addBias_ = 1;
size_t biasDimNum = biasShape_->GetDimNum();
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
return ge::GRAPH_FAILED);
OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_,
OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.",
biasShape_->GetDim(0), expertCount_),
return ge::GRAPH_FAILED);
}
moeGatingTopKTilingData_.set_addBias(addBias_);
OP_CHECK_IF(k_ > expertCount_,
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr()
{
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ > groupCount_,
OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_, k_),
return ge::GRAPH_FAILED);
if (kGroup_ == groupCount_ || groupCount_ == expertCount_) {
kGroup_ = 1;
groupCount_ = 1;
}
moeGatingTopKTilingData_.set_kGroup(kGroup_);
moeGatingTopKTilingData_.set_groupCount(groupCount_);
int64_t groupExpertCount = expertCount_ / groupCount_;
int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L);
moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_);
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
return ge::GRAPH_FAILED);
OP_CHECK_IF(kGroup_ * groupExpertCount < k_,
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
kGroup_ * groupExpertCount, k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupExpertCount < 1,
OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount),
return ge::GRAPH_FAILED);
OP_CHECK_IF(
groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
return ge::GRAPH_FAILED);
OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2,
OP_LOGE(context_,
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
groupExpertCount, groupSelectMode_),
return ge::GRAPH_FAILED);
OP_CHECK_IF(renorm_ != RENORM_NO,
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
return ge::GRAPH_FAILED);
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo()
{
opName_ = context_->GetNodeName();
// 获取输入shape信息
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
xShape_ = &xShapePtr->GetStorageShape();
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
// 获取输出shape
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
yShape_ = &yShapePtr->GetStorageShape();
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
if (outPtr != nullptr) {
outShape_ = &outPtr->GetStorageShape();
}
auto x = context_->GetInputDesc(X_INPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
auto xDtype = x->GetDataType();
OP_CHECK_IF(
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
if (biasShapePtr != nullptr) {
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
OP_CHECK_IF((biasDtype != xDtype),
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
}
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
auto yDtype = yDesc->GetDataType();
OP_CHECK_IF((yDtype != xDtype),
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
return ge::GRAPH_FAILED);
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
auto expertIdDtype = expertIdDesc->GetDataType();
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
return ge::GRAPH_FAILED);
// 获取属性
auto attrs = context_->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
k_ = *kPtr;
moeGatingTopKTilingData_.set_k(k_);
OP_LOGI(context_, "Attr k is: %ld ", k_);
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
if (kGroupPtr != nullptr) {
kGroup_ = *kGroupPtr;
}
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
if (groupCountPtr != nullptr) {
groupCount_ = *groupCountPtr;
}
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
if (groupSelectModePtr != nullptr) {
groupSelectMode_ = *groupSelectModePtr;
}
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
if (renormPtr != nullptr) {
renorm_ = *renormPtr;
}
moeGatingTopKTilingData_.set_renorm(renorm_);
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
if (normTypePtr != nullptr) {
normType_ = *normTypePtr;
}
moeGatingTopKTilingData_.set_normType(normType_);
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
if (outFlagPtr != nullptr) {
outFlag_ = (*outFlagPtr) ? 1 : 0;
}
moeGatingTopKTilingData_.set_outFlag(outFlag_);
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
if (routedScalingFactorPtr != nullptr) {
routedScalingFactor_ = *routedScalingFactorPtr;
}
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
if (epsPtr != nullptr) {
eps_ = *epsPtr;
}
moeGatingTopKTilingData_.set_eps(eps_);
OP_LOGI(context_, "Attr eps is: %f ", eps_);
auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
if (outFlag_ && outDesc != nullptr) {
auto outDtype = outDesc->GetDataType();
OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT),
OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.",
ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()),
return ge::GRAPH_FAILED);
}
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo()
{
auto platformInfo = context_->GetPlatformInfo();
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
socVersion = ascendcPlatform.GetSocVersion();
uint64_t ubSizePlatForm;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
aicoreParams_.ubSize = ubSizePlatForm;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape()
{
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
if (outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
outShape_->GetDimNum(), xShape_->GetDimNum()),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
outShape_->GetDim(0), outShape_->GetDim(0)),
return ge::GRAPH_FAILED);
}
OP_CHECK_IF((yShape_->GetDim(1) != k_),
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
return ge::GRAPH_FAILED);
if (outFlag_ && outShape_ != nullptr) {
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
xShape_->GetDim(1)),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() {
std::vector<int64_t> shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()};
ge::Shape softmaxShape(shape_vec);
uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true);
AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData);
}
void MoeGatingTopKTilingRegbase::SplitRows()
{
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
if (perCoreRows == 0) {
OP_LOGE(context_, "perCoreRows can't be 0.");
return;
}
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
int64_t vmsCount = 0;
if (kGroup_ > MRGSORT_SIZE) {
int64_t index = MRGSORT_SIZE;
while (index < kGroup_) {
index = index * MRGSORT_SIZE;
vmsCount++;
}
}
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
}
ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling()
{
auto ret = CheckInputShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckAttr();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = CheckOutShape();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
CalTmpBufUbSize();
SplitRows();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling()
{
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize()
{
// 计算workspace大小
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling()
{
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
currentWorkspace[0] = workspaceSize_;
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
context_->GetRawTilingData()->GetCapacity());
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
return ge::GRAPH_SUCCESS;
}
uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const
{
return MOE_GATING_TOP_K_REGBASE_TILING_KEY;
}
void MoeGatingTopKTilingRegbase::Reset()
{
opName_ = nullptr;
return;
}
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000);
} // namespace optiling

View File

@@ -1,38 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_gating_top_k_tiling_base.cpp
* \brief
*/
#include "moe_gating_top_k_tiling.h"
#include "register/op_def_registry.h"
#include "../tiling_base/tiling_base.h"
#include "../tiling_base/tiling_templates_registry.h"
#include "error_log.h"
#include "kernel_tiling/kernel_tiling.h"
namespace optiling {
static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context)
{
return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
}
static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context)
{
(void)context;
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(MoeGatingTopK)
.Tiling(TilingForMoeGatingTopK)
.TilingParse<MoeGatingTopKCompileInfo>(TilingPrepareForMoeGatingTopK);
} // namespace optiling

View File

@@ -1,89 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file common.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_COMMON_H
#define MOE_GATING_TOP_K_COMMON_H
#include "kernel_operator.h"
namespace MoeGatingTopK {
using namespace AscendC;
const float MIN_FP32 = *(float *)(&F32_NEG_INF);
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;
constexpr int64_t BLOCK_BYTES = 32;
constexpr int64_t REPEAT_BYTES = 256;
constexpr int64_t REPEAT_BLOCKS = 8;
constexpr int32_t CONSTANT_TWO = 2;
constexpr int32_t CONSTANT_THREE = 3;
constexpr int32_t CONSTANT_FOUR = 4;
constexpr int32_t CONSTANT_EIGHT = 8;
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;
constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;
constexpr int64_t NORM_TYPE_SOFTMAX = 0;
constexpr int64_t NORM_TYPE_SIGMOID = 1;
__aicore__ inline int64_t Ceil(int64_t a, int64_t b)
{
if (b == 0) {
return 0;
}
return (a + b - 1) / b;
}
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes)
{
if (bytes == 0) {
return 0;
}
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes;
}
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes)
{
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
}
template <typename T>
__aicore__ inline T Min(T a, T b)
{
return a > b ? b : a;
}
template <typename T>
__aicore__ inline T Max(T a, T b)
{
return a < b ? b : a;
}
template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 x, T2 y)
{
if (y != 0 && x != 0) {
const T1 quotient = x / y;
return (x % y != 0 && ((x ^ y) >= 0)) ? (quotient + 1) : quotient;
}
return x;
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_COMMON_H

View File

@@ -1,56 +0,0 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -1,63 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k.cpp
* \brief
*/
#include "moe_gating_top_k_e_k_fullload.h"
#include "moe_gating_top_k_without_group.h"
#include "moe_gating_top_k_generalized.h"
#include "error_log.h"
#define TILING_KEY_PER_GROUP_COUNT_32 0
#define TILING_KEY_WITHOUT_GROUP 1
#define TILING_KEY_GENERALIZED 2
using namespace AscendC;
using namespace MoeGatingTopK;
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
if (g_coreType == AIC) {
return;
}
GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling);
if (workspace == nullptr) {
return;
}
GM_ADDR userWS = GetUserWorkspace(workspace);
if (userWS == nullptr) {
return;
}
const MoeGatingTopKTilingData *__restrict t = &tilingData;
TPipe tPipe;
if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) {
MoeGatingTopKEKFullload<DTYPE_X> op;
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
op.Process();
} else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) {
MoeGatingTopKWithoutGroup<DTYPE_X> op;
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
op.Process();
} else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) {
MoeGatingTopKGenerlized<DTYPE_X> op;
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
op.Process();
}
}

View File

@@ -1,46 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_apt.cpp
* \brief
*/
#include "arch35/moe_gating_top_k_regbase.h"
using namespace AscendC;
using namespace MoeGatingTopK;
#define TILING_KEY_REGBASE 10000
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
if (g_coreType == AIC) {
return;
}
if (workspace == nullptr) {
return;
}
GM_ADDR userWS = GetUserWorkspace(workspace);
if (userWS == nullptr) {
return;
}
GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling);
const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in;
TPipe tPipe;
if (TILING_KEY_IS(TILING_KEY_REGBASE)) {
MoeGatingTopKRegbase<DTYPE_X> op;
op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe);
op.Process();
}
}

View File

@@ -1,404 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_e_k_fullload.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H
#define MOE_GATING_TOP_K_E_K_FULLLOAD_H
#include "kernel_operator.h"
#include "common.h"
namespace MoeGatingTopK {
using namespace AscendC;
template <typename T>
class MoeGatingTopKEKFullload {
public:
__aicore__ inline MoeGatingTopKEKFullload(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
__aicore__ inline void Process();
private:
__aicore__ inline void CopyInBias();
__aicore__ inline void CopyInX(int64_t progress);
__aicore__ inline void ComputeX();
__aicore__ inline void SortInGroup();
__aicore__ inline void SelectTopKGroupIndex();
__aicore__ inline void SelectTopKExpertIdx();
__aicore__ inline void SelectTopKExpertScore();
__aicore__ inline void CopyOut(int64_t progress);
private:
TPipe *pipe_;
TQue<QuePosition::VECIN, 1> xInQueue_;
TBuf<TPosition::VECCALC> biasInQueue_;
TQue<QuePosition::VECOUT, 1> yOutQueue_;
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
TQue<QuePosition::VECOUT, 1> outOutQueue_;
TQue<QuePosition::VECOUT, 1> xBiasQueue_;
TQue<QuePosition::VECOUT, 1> xSigmoidQueue_;
TQue<QuePosition::VECIN, 1> sigmoidTmpQueue_;
TQue<QuePosition::VECIN, 1> sortedInGroupQueue_;
TQue<QuePosition::VECIN, 1> sortedGroupQueue_;
TBuf<TPosition::VECCALC> calcTmpBuffer_;
GlobalTensor<T> xGm_;
GlobalTensor<T> biasGm_;
GlobalTensor<T> yGm_;
GlobalTensor<int32_t> expertIdxGm_;
GlobalTensor<T> outGm_;
int64_t blockIdx_;
int64_t perCoreRowCount_;
int64_t curCoreRowCount_;
int64_t expertCount_;
bool addBias_;
int64_t k_;
int64_t kGroup_;
int64_t groupCount_;
int64_t groupSelectMode_;
int64_t renorm_;
int64_t normType_;
int64_t outFlag_;
float routedScalingFactor_;
float eps_;
int64_t expertCountAlign_;
int64_t kAlign_;
int64_t perGroupExpertCount_;
const MoeGatingTopKTilingData *tilingData_;
};
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInBias()
{
LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
} else {
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE, expertCount_);
}
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInX(int64_t row)
{
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
} else {
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
expertCount_);
}
xInQueue_.EnQue(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::ComputeX()
{
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.AllocTensor<float>();
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
LocalTensor<float> xBiasTensor = xBiasQueue_.AllocTensor<float>();
LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
LocalTensor<uint8_t> sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor<uint8_t>(); // 临时空间可以复用
Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_);
PipeBarrier<PIPE_V>();
if (addBias_) {
Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_);
} else {
Adds(xBiasTensor, xSigmoidTensor, static_cast<float>(0), expertCount_);
}
xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
xBiasQueue_.EnQue<float>(xBiasTensor);
xInQueue_.FreeTensor(xInLocalTensor);
sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SortInGroup()
{
LocalTensor<float> xBiasTensor = xBiasQueue_.DeQue<float>();
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.AllocTensor<float>(); // 组内排序的结果, 后续归并需要
LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>(); // 用于存储排序时的索引
ArithProgression(indexTensor.ReinterpretCast<int32_t>(), 0, 1, expertCount_); // 生成组索引0 1 2 ......
PipeBarrier<PIPE_V>();
Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // 组内排序
sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
xBiasQueue_.FreeTensor(xBiasTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKGroupIndex()
{
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>();
LocalTensor<float> top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor<float>(); // 这个临时空间可以复用
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
indexTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
indexTensor.SetValue(1, static_cast<uint32_t>(0));
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = 8;
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride = 8;
gatherMaskParams.src1RepeatStride = 0;
GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast<uint32_t>(64),
gatherMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
LocalTensor<float> groupTop2SumTensor = top2ValueInGroupTensor;
PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1,
1); // 计算每个组内最大的两个数之和
PipeBarrier<PIPE_V>();
LocalTensor<uint32_t> groupIndexTensor = indexTensor;
ArithProgression(groupIndexTensor.ReinterpretCast<int32_t>(), 0, 1, groupCount_); // 生成组索引
PipeBarrier<PIPE_V>();
// 用最小值补到32个数
int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_;
if (duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX << groupCount_;
uint64_t mask[2] = {mask0, 0};
Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8);
PipeBarrier<PIPE_V>();
}
// 排序将kgroup选出来
LocalTensor<float> sortedGroupTensor = sortedGroupQueue_.AllocTensor<float>();
Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1);
PipeBarrier<PIPE_V>();
LocalTensor<int32_t> sortedGroupIndexTensor = indexTensor.ReinterpretCast<int32_t>();
// 提取组序号
uint8_t src1Pattern = 2; // 内置固定模式
GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
// 需要将组排序(这里是降序所以下mrgsor的时候反着取3、2、1、0)
Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_);
PipeBarrier<PIPE_V>();
duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_;
if (duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX << kGroup_;
uint64_t mask[2] = {mask0, 0};
Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8);
PipeBarrier<PIPE_V>();
}
Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast<uint32_t>(), 1);
PipeBarrier<PIPE_V>();
src1Pattern = 1;
GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast<uint32_t>(0), {1, 1, 0, 0},
rsvdCnt);
PipeBarrier<PIPE_V>();
Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_);
sortedGroupQueue_.FreeTensor(sortedGroupTensor);
sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertIdx()
{
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.AllocTensor<int32_t>();
LocalTensor<int32_t> topKGroupIndexTensor = calcTmpBuffer_.Get<int32_t>();
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
LocalTensor<float> sortedExpertTensor = xInQueue_.AllocTensor<float>();
AscendC::MrgSort4Info params;
params.elementLengths[0] = k_;
params.elementLengths[1] = k_;
params.elementLengths[2] = k_;
params.elementLengths[3] = k_;
params.ifExhaustedSuspension = true;
params.validBit = 0b1111;
params.repeatTimes = 1;
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2;
int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2;
int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2;
int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2;
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
AscendC::MrgSortSrcList<float> srcList;
srcList.src1 = sortedInGroupTensor[listOffset1];
srcList.src2 = sortedInGroupTensor[listOffset2];
srcList.src3 = sortedInGroupTensor[listOffset3];
srcList.src4 = sortedInGroupTensor[listOffset4];
MrgSort<float>(sortedExpertTensor, srcList, params);
PipeBarrier<PIPE_V>();
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
uint8_t src1Pattern = 2; // 内置固定模式
GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
xInQueue_.FreeTensor(sortedExpertTensor);
expertIdxOutQueue_.EnQue(expertIdxTensor);
sortedInGroupQueue_.FreeTensor(sortedInGroupTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertScore()
{
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
LocalTensor<int32_t> expertByteIdxTensor = calcTmpBuffer_.Get<int32_t>();
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
LocalTensor<T> yTensor = yOutQueue_.AllocTensor<T>();
LocalTensor<float> yOutTensor;
if constexpr (!IsSameType<T, float>::value) {
yOutTensor = yTensor.template ReinterpretCast<float>()[kAlign_];
} else {
yOutTensor = yTensor;
}
Muls(expertByteIdxTensor, expertIdxTensor, static_cast<int32_t>(sizeof(float)), k_);
PipeBarrier<PIPE_V>();
Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast<uint32_t>(),
static_cast<uint32_t>(0), k_);
LocalTensor<float> calTensor = calcTmpBuffer_.Get<float>();
PipeBarrier<PIPE_V>();
ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float sumValue = calTensor.GetValue(0) + eps_;
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Duplicate(calTensor, sumValue, k_);
PipeBarrier<PIPE_V>();
Div(yOutTensor, yOutTensor, calTensor, k_);
PipeBarrier<PIPE_V>();
Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_);
if constexpr (!IsSameType<T, float>::value) {
PipeBarrier<PIPE_V>();
Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_);
}
xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
expertIdxOutQueue_.EnQue<int32_t>(expertIdxTensor);
yOutQueue_.EnQue(yTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyOut(int64_t row)
{
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
dataCopyParams.blockLen = k_ * sizeof(int32_t);
DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams);
xSigmoidQueue_.FreeTensor(xSigmoidTensor);
expertIdxOutQueue_.FreeTensor(expertIdxTensor);
yOutQueue_.FreeTensor(yOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
tilingData_ = tilingData;
pipe_ = tPipe;
blockIdx_ = GetBlockIdx();
perCoreRowCount_ = tilingData_->perCoreRowCount;
if (blockIdx_ == GetBlockNum() - 1) {
curCoreRowCount_ = tilingData_->lastCoreRowCount;
} else {
curCoreRowCount_ = tilingData_->perCoreRowCount;
}
expertCount_ = tilingData_->expertCount;
addBias_ = tilingData_->addBias == 1;
k_ = tilingData_->k;
kGroup_ = tilingData_->kGroup;
groupCount_ = tilingData_->groupCount;
perGroupExpertCount_ = tilingData_->perGroupExpertCount;
routedScalingFactor_ = tilingData_->routedScalingFactor;
eps_ = tilingData_->eps;
expertCountAlign_ = Align(expertCount_, sizeof(float));
kAlign_ = Align(expertCount_, sizeof(float));
// init input gm buf
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
// init output gm buf
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
// init que
pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float)));
pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t)));
pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2);
pipe_->InitBuffer(sortedGroupQueue_, 2,
(groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM *
sizeof(float) * 2);
pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize);
}
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Process()
{
CopyInBias();
for (int64_t row = 0; row < curCoreRowCount_; row++) {
CopyInX(row);
ComputeX();
SortInGroup();
SelectTopKGroupIndex();
SelectTopKExpertIdx();
SelectTopKExpertScore();
CopyOut(row);
}
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H

View File

@@ -1,669 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_generalized.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H
#define MOE_GATING_TOP_K_E_K_GENERALIZED_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
template <typename T>
class MoeGatingTopKGenerlized {
public:
__aicore__ inline MoeGatingTopKGenerlized(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
__aicore__ inline void Process();
private:
__aicore__ inline void CopyInBiasAndInitExpertId();
__aicore__ inline void CopyInX(int64_t progress);
__aicore__ inline void ComputeX();
__aicore__ inline void CopuOutXNorm(int64_t row);
__aicore__ inline void SortInGroup();
__aicore__ inline void SelectTopKGroupIndex();
__aicore__ inline void SelectTopKExpertIdx();
__aicore__ inline void SelectTopKExpertScore();
__aicore__ inline void CumputeActualTopKExpertId();
__aicore__ inline void CopyOut(int64_t row);
private:
TPipe *pipe_;
TQue<QuePosition::VECIN, 1> xInQueue_;
TQue<QuePosition::VECOUT, 1> yOutQueue_;
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
TQue<QuePosition::VECOUT, 1> outOutQueue_;
TBuf<TPosition::VECCALC> biasBuf_; // 存放输入bias
TBuf<TPosition::VECCALC> expertIdBuf_; // 专家编号
TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // 存放加了bias之后的值
TBuf<TPosition::VECCALC> xNormBuf_; // 存放计算sigmoid或softmax的值
TBuf<TPosition::VECCALC> sortedInGroupBuf_; // 存放组内排序后的结果
TBuf<TPosition::VECCALC> topKExpertIdBuf_;
TBuf<TPosition::VECCALC> sortedGroupIndexBuf_;
TBuf<TPosition::VECCALC> calcTmpBuf_;
GlobalTensor<T> xGm_;
GlobalTensor<T> biasGm_;
GlobalTensor<T> yGm_;
GlobalTensor<int32_t> expertIdxGm_;
GlobalTensor<float> outGm_;
int64_t blockIdx_ = 0;
int64_t perCoreRowCount_ = 0;
int64_t curCoreRowCount_ = 0;
int64_t expertCount_ = 0;
bool addBias_ = false;
int64_t k_ = 0;
int64_t kGroup_ = 0;
int64_t groupCount_ = 0;
int64_t groupCountAlign_ = 0;
int64_t perGroupExpertCount_ = 0;
int64_t perGroupExpertCountAlign_ = 0;
int64_t groupSelectMode_ = 0;
int64_t renorm_ = 0;
int64_t normType_ = 0;
int64_t outFlag_ = 0;
int64_t expertCountAlign_ = 0;
int64_t kAlign_ = 0;
bool isAlign_ = false;
const MoeGatingTopKTilingData *tilingData_;
};
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInBiasAndInitExpertId()
{
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
DataCopyExtParams dataCopyParams;
dataCopyParams.blockCount = groupCount_;
dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
dataCopyParams.srcStride = 0;
dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if (addBias_) {
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
} else {
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
expertCountAlign_);
PipeBarrier<PIPE_V>();
}
if (!isAlign_) {
int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
int duplicateIndex = perGroupExpertCount_ - duplicateNum;
if (duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(biasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
}
}
}
ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCountAlign_);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInX(int64_t row)
{
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
DataCopyExtParams dataCopyParams;
dataCopyParams.blockCount = groupCount_;
dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
dataCopyParams.srcStride = 0;
dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
} else {
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
dataCopyPadParams);
}
xInQueue_.EnQue(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::ComputeX()
{
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
if constexpr (!IsSameType<T, float>::value) {
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
expertCountAlign_);
PipeBarrier<PIPE_V>();
}
int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
int duplicateIndex = perGroupExpertCount_ - duplicateNum;
if (!isAlign_ && duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(xInLocalTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
(perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES);
PipeBarrier<PIPE_V>();
}
if (normType_ == 1) { // sigmoid
LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_);
PipeBarrier<PIPE_V>();
}
else if (normType_ == 0) { // softmax
LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float maxValue = reduceValueTensor.GetValue(0);
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_);
PipeBarrier<PIPE_V>();
Exp(xNormTensor, xNormTensor, expertCountAlign_);
PipeBarrier<PIPE_V>();
ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_);
eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float sumValue = reduceValueTensor.GetValue(0);
eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_);
PipeBarrier<PIPE_V>();
}
if (addBias_) {
Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_);
} else {
DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
}
if (!isAlign_ && duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
PipeBarrier<PIPE_V>();
Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex],
FLOAT32_NEG_INF, // MIN_FP32,
mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
}
xInQueue_.FreeTensor(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopuOutXNorm(int64_t row)
{
LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
outOutQueue_.EnQue<float>(outOutTensor);
outOutTensor = outOutQueue_.DeQue<float>();
DataCopyExtParams dataCopyParams{
static_cast<uint16_t>(groupCount_), static_cast<uint32_t>(perGroupExpertCount_ * sizeof(float)),
static_cast<uint32_t>((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0};
DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
outOutQueue_.FreeTensor(outOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SortInGroup()
{
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
LocalTensor<float> tmpLocal = calcTmpBuf_.Get<float>();
if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) {
PipeBarrier<PIPE_V>();
Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_);
} else {
for (int64_t group = 0; group < groupCount_; group++) {
PipeBarrier<PIPE_V>();
Sort<float, true>(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO],
xNormWithBiasTensor[group * perGroupExpertCountAlign_],
expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal,
perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM);
}
}
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKGroupIndex()
{
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
LocalTensor<float> valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, 0);
LocalTensor<uint32_t> maskTensor =
calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float));
LocalTensor<float> topValueInGroupTensor =
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float));
LocalTensor<uint32_t> groupIndex =
calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float));
LocalTensor<float> sortedTopValue =
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float));
LocalTensor<float> sortTmp =
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float));
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
PipeBarrier<PIPE_V>();
if (groupSelectMode_ == 1) { // top2 sum
// 提取每组组前两个元素
maskTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
maskTensor.SetValue(1, static_cast<uint32_t>(0));
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = groupCount_;
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride =
Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES);
gatherMaskParams.src1RepeatStride = 0;
GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true,
static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
// 计算每个组前两个数的和
PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor,
Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1,
CONSTANT_EIGHT); // 计算每个组内最大的两个数之和
} else {
maskTensor.SetValue(0, static_cast<uint32_t>(1)); // b0101
maskTensor.SetValue(1, static_cast<uint32_t>(0));
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = groupCount_;
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32);
gatherMaskParams.src1RepeatStride = 0;
GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true,
static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
}
PipeBarrier<PIPE_V>();
// 生成组索引
ArithProgression(groupIndex.ReinterpretCast<int32_t>(), static_cast<int32_t>(0), static_cast<int32_t>(1),
groupCount_); // 生成组索引
PipeBarrier<PIPE_V>();
int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM;
int duplicateIndex = groupCount_ - duplicateNum;
if (duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(topValueInGroupTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1,
REPEAT_BLOCKS);
PipeBarrier<PIPE_V>();
}
PipeBarrier<PIPE_V>();
// 排序
Sort<float, true>(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32));
PipeBarrier<PIPE_V>();
// 提取组序号
uint8_t src1Pattern = 2; // 内置固定模式
GatherMask(groupIndex, sortedTopValue.template ReinterpretCast<uint32_t>(), src1Pattern, false,
static_cast<uint32_t>(0),
{1, static_cast<uint8_t>(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt);
PipeBarrier<PIPE_V>();
duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM;
if (duplicateNum > 0) {
duplicateIndex = kGroup_ - duplicateNum;
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
PipeBarrier<PIPE_V>();
Duplicate(groupIndex.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS);
}
// 将筛选出来的组序号降序排列
LocalTensor<float> sortedGroupIndex = sortedGroupIndexBuf_.Get<float>();
PipeBarrier<PIPE_V>();
Sort<float, true>(sortedGroupIndex, groupIndex.ReinterpretCast<float>(), groupIndex, sortTmp, Ceil(kGroup_, 32));
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertIdx()
{
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
LocalTensor<int32_t> sortedGroupIndex = sortedGroupIndexBuf_.Get<int32_t>();
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
LocalTensor<float> mrgSort0Tensor = calcTmpBuf_.Get<float>();
uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0};
uint16_t lenArr[CONSTANT_FOUR] = {
static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_),
static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_)};
MrgSort4Info params{lenArr, false, 0b1111, 1};
MrgSortSrcList<float> srcList;
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) {
int64_t mrgLen = Min(i + 1, CONSTANT_FOUR);
if (mrgLen > 1) {
if (mrgLen == MERGE_LIST_FOUR) {
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2;
} else if (mrgLen == MERGE_LIST_THREE) {
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
offset[3] = 0;
params.elementLengths[3] = 0;
params.validBit = 0b111;
} else {
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
offset[2] = 0;
offset[3] = 0;
params.elementLengths[2] = 0;
params.elementLengths[3] = 0;
params.validBit = 0b11;
}
srcList.src1 = sortedInGroupTensor[offset[0]];
srcList.src2 = sortedInGroupTensor[offset[1]];
srcList.src3 = sortedInGroupTensor[offset[2]];
srcList.src4 = sortedInGroupTensor[offset[3]];
PipeBarrier<PIPE_V>();
MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params);
} else {
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
PipeBarrier<PIPE_V>();
DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]],
perGroupExpertCountAlign_ * 2);
}
}
int32_t baseLoop = 4;
LocalTensor<float> srcTensor = mrgSort0Tensor;
LocalTensor<float> dstTensor = mrgSort0Tensor;
for (int i = 0; i < tilingData_->vmsCount; i++) {
if (i % 2 == 0) {
srcTensor = mrgSort0Tensor;
dstTensor = sortedInGroupTensor;
} else {
srcTensor = sortedInGroupTensor;
dstTensor = mrgSort0Tensor;
}
int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR;
int32_t quotient = kGroup_ / nextBaseRow;
int32_t remainder = kGroup_ - quotient * nextBaseRow;
if (quotient > 0) {
MrgSort4Info params;
MrgSortSrcList<float> srcList;
params.ifExhaustedSuspension = false;
params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
params.validBit = 0b1111;
params.repeatTimes = 1;
for (int j = 0; j < quotient; j++) {
srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j];
srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)];
srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)];
srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)];
PipeBarrier<PIPE_V>();
MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params);
}
}
if (remainder > 0) {
int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2;
int32_t mrgLen = CeilDiv(remainder, baseLoop);
int32_t tailRow = remainder - (mrgLen - 1) * baseLoop;
if (mrgLen > 1) {
MrgSort4Info params;
MrgSortSrcList<float> srcList;
params.repeatTimes = 1;
params.ifExhaustedSuspension = false;
params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
srcList.src1 = srcTensor[baseOffset];
srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2];
if (mrgLen == MERGE_LIST_FOUR) {
srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3];
params.elementLengths[3] = perGroupExpertCount_ * tailRow;
params.validBit = 0b1111;
} else if (mrgLen == MERGE_LIST_THREE) {
srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
params.elementLengths[2] = perGroupExpertCount_ * tailRow;
params.elementLengths[3] = 0;
params.validBit = 0b111;
} else {
params.elementLengths[1] = perGroupExpertCount_ * tailRow;
params.elementLengths[2] = 0;
params.elementLengths[3] = 0;
params.validBit = 0b11;
}
PipeBarrier<PIPE_V>();
MrgSort(dstTensor[baseOffset], srcList, params);
} else {
PipeBarrier<PIPE_V>();
DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2);
}
}
baseLoop = nextBaseRow;
}
GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES);
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
gatherMaskParams.src1RepeatStride = 0;
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
uint8_t src1Pattern = 2; // 内置固定模式
PipeBarrier<PIPE_V>();
GatherMask(topKExpertId, dstTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertScore()
{
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
PipeBarrier<PIPE_V>();
Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
PipeBarrier<PIPE_V>();
Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
k_);
bool needRenorm = (normType_ == 1 ) || // 情况1sigmoid + renorm
(normType_ == 0 && renorm_ == 1); // 情况3softmax + renorm
if (needRenorm) {
LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[32];
PipeBarrier<PIPE_V>();
ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Duplicate(tmpTensor, sumValue, k_);
PipeBarrier<PIPE_V>();
Div(yOutTensor, yOutTensor, tmpTensor, k_);
}
PipeBarrier<PIPE_V>();
Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
if constexpr (!IsSameType<T, float>::value) {
PipeBarrier<PIPE_V>();
Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
}
yOutQueue_.EnQue<float>(yOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CumputeActualTopKExpertId()
{
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
LocalTensor<float> topKExpertIdFp32 = calcTmpBuf_.Get<float>();
PipeBarrier<PIPE_V>();
Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_);
PipeBarrier<PIPE_V>();
Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_);
PipeBarrier<PIPE_V>();
Cast(expertIdxOut, topKExpertIdFp32, RoundMode::CAST_TRUNC, k_);
PipeBarrier<PIPE_V>();
Muls(expertIdxOut, expertIdxOut, static_cast<int32_t>(perGroupExpertCountAlign_ - perGroupExpertCount_), k_);
PipeBarrier<PIPE_V>();
Sub(expertIdxOut, topKExpertId, expertIdxOut, k_);
expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyOut(int64_t row)
{
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
dataCopyParams.blockLen = k_ * sizeof(int32_t);
DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
yOutQueue_.FreeTensor(yOutTensor);
expertIdxOutQueue_.FreeTensor(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
tilingData_ = tilingData;
pipe_ = tPipe;
blockIdx_ = GetBlockIdx();
perCoreRowCount_ = tilingData_->perCoreRowCount;
if (blockIdx_ == GetBlockNum() - 1) {
curCoreRowCount_ = tilingData_->lastCoreRowCount;
} else {
curCoreRowCount_ = tilingData_->perCoreRowCount;
}
expertCount_ = tilingData_->expertCount;
addBias_ = tilingData_->addBias == 1;
k_ = tilingData_->k;
kGroup_ = tilingData_->kGroup;
groupCount_ = tilingData_->groupCount;
groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
perGroupExpertCount_ = tilingData_->perGroupExpertCount;
perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign;
renorm_ = tilingData_->renorm;
normType_ = tilingData_->normType;
groupSelectMode_ = tilingData_->groupSelectMode;
expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float));
kAlign_ = Align(k_, sizeof(float));
isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_;
// init input gm buf
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
// init output gm buf
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
// init que
pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float));
pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t));
pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t)));
pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO);
pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t));
pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10);
}
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Process()
{
CopyInBiasAndInitExpertId();
for (int64_t row = 0; row < curCoreRowCount_; row++) {
CopyInX(row);
ComputeX();
if (tilingData_->outFlag) {
CopuOutXNorm(row);
}
SortInGroup();
SelectTopKGroupIndex();
SelectTopKExpertIdx();
SelectTopKExpertScore();
CumputeActualTopKExpertId();
CopyOut(row);
}
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H

View File

@@ -1,338 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_gating_top_k_without_group.h
* \brief
*/
#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
#include "kernel_operator.h"
#include "common.h"
#include "kernel_utils.h"
namespace MoeGatingTopK {
using namespace AscendC;
template <typename T>
class MoeGatingTopKWithoutGroup {
public:
__aicore__ inline MoeGatingTopKWithoutGroup(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
__aicore__ inline void Process();
private:
__aicore__ inline void CopyInBiasAndInitExpertId();
__aicore__ inline void CopyInX(int64_t progress);
__aicore__ inline void ComputeX();
__aicore__ inline void CopuOutXNorm(int64_t row);
__aicore__ inline void SelectTopKExpertIdx();
__aicore__ inline void SelectTopKExpertScore();
__aicore__ inline void CopyOut(int64_t row);
private:
TPipe *pipe_;
TQue<QuePosition::VECIN, 1> xInQueue_;
TQue<QuePosition::VECOUT, 1> yOutQueue_;
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
TQue<QuePosition::VECOUT, 1> outOutQueue_;
TBuf<TPosition::VECCALC> biasBuf_; // 存放输入bias
TBuf<TPosition::VECCALC> expertIdBuf_; // 专家编号
TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // 存放加了bias之后的值
TBuf<TPosition::VECCALC> xNormBuf_; // 存放计算sigmoid或softmax的值
TBuf<TPosition::VECCALC> topKExpertIdBuf_;
TBuf<TPosition::VECCALC> calcTmpBuf_;
GlobalTensor<T> xGm_;
GlobalTensor<T> biasGm_;
GlobalTensor<T> yGm_;
GlobalTensor<int32_t> expertIdxGm_;
GlobalTensor<float> outGm_;
int64_t blockIdx_ = 0;
int64_t perCoreRowCount_ = 0;
int64_t curCoreRowCount_ = 0;
int64_t expertCount_ = 0;
bool addBias_ = false;
bool outFlag_ = false;
int64_t k_ = 0;
int64_t renorm_ = 0;
int64_t normType_ = 0;
int64_t expertCountAlign_ = 0;
const MoeGatingTopKTilingData *tilingData_;
};
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInBiasAndInitExpertId()
{
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if (addBias_) {
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
} else {
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
expertCountAlign_);
PipeBarrier<PIPE_V>();
}
}
ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCount_);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInX(int64_t row)
{
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
} else {
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
dataCopyPadParams);
}
xInQueue_.EnQue(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::ComputeX()
{
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
if constexpr (!IsSameType<T, float>::value) {
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
expertCount_);
PipeBarrier<PIPE_V>();
}
if (normType_ == 1) { // sigmoid
LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_);
PipeBarrier<PIPE_V>();
} else if (normType_ == 0) { // sigmoid
LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[8];
ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float maxValue = reduceValueTensor.GetValue(0);
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_);
PipeBarrier<PIPE_V>();
Exp(xNormTensor, xNormTensor, expertCount_);
PipeBarrier<PIPE_V>();
ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_);
eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float sumValue = reduceValueTensor.GetValue(0);
eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_);
PipeBarrier<PIPE_V>();
}
if (addBias_) {
Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_);
} else {
DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
}
int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM;
int duplicateIndex = expertCount_ - duplicateNum;
if (duplicateNum > 0) {
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1);
PipeBarrier<PIPE_V>();
}
xInQueue_.FreeTensor(xInLocalTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopuOutXNorm(int64_t row)
{
LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
outOutQueue_.EnQue<float>(outOutTensor);
outOutTensor = outOutQueue_.DeQue<float>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(float)), 0, 0, 0};
DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
outOutQueue_.FreeTensor(outOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertIdx()
{
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
LocalTensor<float> sortedScore = calcTmpBuf_.Get<float>();
LocalTensor<float> sortTmp = calcTmpBuf_.Get<float>()[expertCountAlign_ * CONSTANT_TWO];
PipeBarrier<PIPE_ALL>();
Sort<float, true>(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp,
expertCountAlign_ / ONE_REPEAT_SORT_NUM);
GatherMaskParams gatherMaskParams;
gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES);
gatherMaskParams.src0BlockStride = 1;
gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
gatherMaskParams.src1RepeatStride = 0;
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
uint8_t src1Pattern = 2; // 内置固定模式
PipeBarrier<PIPE_V>();
GatherMask(topKExpertId, sortedScore.template ReinterpretCast<int32_t>(), src1Pattern, false,
static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
DataCopy(expertIdxOut, topKExpertId, expertCountAlign_);
expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertScore()
{
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
PipeBarrier<PIPE_V>();
Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
PipeBarrier<PIPE_V>();
Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
k_);
bool needRenorm = (normType_ == 1 ) || // 情况1sigmoid + renorm
(normType_ == 0 && renorm_ == 1); // 情况3softmax + renorm
if (needRenorm == 1) {
LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
PipeBarrier<PIPE_V>();
ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
SetFlag<HardEvent::V_S>(eventIdVToS);
WaitFlag<HardEvent::V_S>(eventIdVToS);
float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventIdSToV);
WaitFlag<HardEvent::S_V>(eventIdSToV);
Duplicate(tmpTensor, sumValue, k_);
PipeBarrier<PIPE_V>();
Div(yOutTensor, yOutTensor, tmpTensor, k_);
}
PipeBarrier<PIPE_V>();
Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
if constexpr (!IsSameType<T, float>::value) {
PipeBarrier<PIPE_V>();
Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
}
yOutQueue_.EnQue<float>(yOutTensor);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyOut(int64_t row)
{
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
dataCopyParams.blockLen = k_ * sizeof(int32_t);
DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
yOutQueue_.FreeTensor(yOutTensor);
expertIdxOutQueue_.FreeTensor(expertIdxOut);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
GM_ADDR out, GM_ADDR workspace,
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
tilingData_ = tilingData;
pipe_ = tPipe;
blockIdx_ = GetBlockIdx();
perCoreRowCount_ = tilingData_->perCoreRowCount;
if (blockIdx_ == GetBlockNum() - 1) {
curCoreRowCount_ = tilingData_->lastCoreRowCount;
} else {
curCoreRowCount_ = tilingData_->perCoreRowCount;
}
expertCount_ = tilingData_->expertCount;
addBias_ = tilingData_->addBias == 1;
outFlag_ = tilingData_->outFlag == 1;
k_ = tilingData_->k;
renorm_ = tilingData_->renorm;
normType_ = tilingData_->normType;
expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
// init input gm buf
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
// init output gm buf
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
// init que
pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float));
pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t));
pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
// init calc buf
pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t));
// init tmp buf
pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT);
}
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Process()
{
CopyInBiasAndInitExpertId();
for (int64_t row = 0; row < curCoreRowCount_; row++) {
CopyInX(row);
ComputeX();
if (outFlag_) {
CopuOutXNorm(row);
}
SelectTopKExpertIdx();
SelectTopKExpertScore();
CopyOut(row);
}
}
} // namespace MoeGatingTopK
#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H

View File

@@ -1,51 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling.h
* \brief
*/
#pragma once
#include <vector>
#include <graph/tensor.h>
#include "data_copy_transpose_tiling_def.h"
namespace optiling {
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
optiling::CopyTransposeTiling &tiling)
{
constexpr int64_t B_INDEX = 0;
constexpr int64_t N_INDEX = 1;
constexpr int64_t S_INDEX = 2;
constexpr int64_t H_INDEX = 3;
std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
tiling.set_dstShapeB(dstShapeInfo[B_INDEX]);
tiling.set_dstShapeN(dstShapeInfo[N_INDEX]);
tiling.set_dstShapeS(dstShapeInfo[S_INDEX]);
tiling.set_dstShapeH(dstShapeInfo[H_INDEX]);
tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
tiling.set_srcShapeB(srcShapeInfo[B_INDEX]);
tiling.set_srcShapeN(srcShapeInfo[N_INDEX]);
tiling.set_srcShapeS(srcShapeInfo[S_INDEX]);
tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]);
tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
} // namespace optiling

View File

@@ -1,43 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file data_copy_transpose_tiling_def.h
* \brief
*/
#pragma once
#include <cstdint>
#include <register/tilingdata_base.h>
namespace optiling {
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
} // namespace optiling

View File

@@ -1,56 +0,0 @@
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...) \
do { \
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
do { \
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGE(opname, ...) \
do { \
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
printf("\n"); \
} while (0)
#define OP_LOGD(opname, ...)
namespace optiling {
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
do { \
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
} while (0)
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
#define OP_CHECK_IF(cond, log_func, expr) \
do { \
if (cond) { \
log_func; \
expr; \
} \
} while (0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if ((ptr) == nullptr) { \
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
} // namespace optiling
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -1,256 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_base.h
* \brief
*/
#pragma once
#include <sstream>
#include <exe_graph/runtime/tiling_context.h>
#include <graph/utils/type_utils.h>
#include "tiling/platform/platform_ascendc.h"
#include "error_log.h"
#ifdef ASCENDC_OP_TEST
#define ASCENDC_EXTERN_C extern "C"
#else
#define ASCENDC_EXTERN_C
#endif
namespace Ops {
namespace Transformer {
namespace OpTiling {
struct AiCoreParams {
uint64_t ubSize = 0;
uint64_t blockDim = 0;
uint64_t aicNum = 0;
uint64_t l1Size = 0;
uint64_t l0aSize = 0;
uint64_t l0bSize = 0;
uint64_t l0cSize = 0;
};
struct CompileInfoCommon {
uint32_t aivNum;
uint32_t aicNum;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
uint64_t l2CacheSize;
int64_t coreNum;
int32_t socVersion;
uint32_t rsvd;
};
struct FlashAttentionScoreGradCompileInfo {
uint32_t aivNum;
uint32_t aicNum;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
uint64_t l2CacheSize;
int64_t coreNum;
platform_ascendc::SocVersion socVersion;
};
struct FACompileInfoCommon {
uint32_t aivNum;
uint32_t aicNum;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
uint64_t l2CacheSize;
int64_t coreNum;
int32_t socVersion;
uint32_t rsvd;
};
class TilingBaseClass {
public:
explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
{}
virtual ~TilingBaseClass() = default;
// Tiling执行框架
// 1、GRAPH_SUCCESS: 成功并且不需要继续执行后续Tiling类的实现
// 2、GRAPH_FAILED: 失败中止整个Tiling流程
// 3、GRAPH_PARAM_INVALID: 本类不支持需要继续往下执行其他Tiling类的实现
ge::graphStatus DoTiling()
{
auto ret = GetShapeAttrsInfo();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = GetPlatformInfo();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
if (!IsCapable()) {
return ge::GRAPH_PARAM_INVALID;
}
ret = DoOpTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = DoLibApiTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = GetWorkspaceSize();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
ret = PostTiling();
if (ret != ge::GRAPH_SUCCESS) {
return ret;
}
context_->SetTilingKey(GetTilingKey());
DumpTilingInfo();
return ge::GRAPH_SUCCESS;
}
// 更新 context
virtual void Reset(gert::TilingContext* context)
{
context_ = context;
}
protected:
virtual bool IsCapable() = 0;
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
virtual ge::graphStatus GetPlatformInfo() = 0;
// 2、获取INPUT/OUTPUT/ATTR信息
virtual ge::graphStatus GetShapeAttrsInfo() = 0;
// 3、计算数据切分TilingData
virtual ge::graphStatus DoOpTiling() = 0;
// 4、计算高阶API的TilingData
virtual ge::graphStatus DoLibApiTiling() = 0;
// 5、计算TilingKey
[[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
// 6、计算Workspace 大小
virtual ge::graphStatus GetWorkspaceSize() = 0;
// 7、保存Tiling数据
virtual ge::graphStatus PostTiling() = 0;
// 8、Dump Tiling数据
virtual void DumpTilingInfo()
{
int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
if (enable != 1) {
return;
}
auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
auto bufLen = context_->GetRawTilingData()->GetDataSize();
std::ostringstream oss;
oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
<< ", content:";
for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
oss << *(buf + i) << ",";
if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
OP_LOGD(context_, "%s", oss.str().c_str());
oss.str("");
}
}
OP_LOGD(context_, "%s", oss.str().c_str());
}
static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
{
uint32_t ration;
if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
return sliceNum;
}
ration = aivCoreNum / aicCoreNum;
return (sliceNum + (ration - 1)) / ration;
}
template <typename T>
[[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
{
std::ostringstream oss;
oss << "[";
if (shape.GetDimNum() > 0) {
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
oss << shape.GetDim(i) << ", ";
}
oss << shape.GetDim(shape.GetDimNum() - 1);
}
oss << "]";
return oss.str();
}
[[nodiscard]] std::string GetTensorDebugStr(
const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
{
if (shape == nullptr || tensor == nullptr) {
return "nil ";
}
std::ostringstream oss;
oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
oss << "(format: "
<< ge::TypeUtils::FormatToSerialString(
static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
<< "),";
oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
return oss.str();
}
[[nodiscard]] std::string GetTilingContextDebugStr()
{
std::ostringstream oss;
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
oss << "input" << i << ": ";
oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
}
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
oss << "output" << i << ": ";
oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
}
return oss.str();
}
[[nodiscard]] std::string GetTilingDataDebugStr() const
{
auto rawTilingData = context_->GetRawTilingData();
auto rawTilingDataSize = rawTilingData->GetDataSize();
auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
size_t len = rawTilingDataSize / sizeof(int32_t);
std::ostringstream oss;
for (size_t i = 0; i < len; i++) {
oss << data[i] << ", ";
}
return oss.str();
}
protected:
gert::TilingContext* context_ = nullptr;
std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
uint32_t blockDim_{0};
uint64_t workspaceSize_{0};
uint64_t tilingKey_{0};
AiCoreParams aicoreParams_;
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -1,63 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_key.h
* \brief
*/
#pragma once
#include <cstdint>
namespace Ops {
namespace Transformer {
namespace OpTiling {
constexpr uint64_t RecursiveSum()
{
return 0;
}
constexpr uint64_t kBase = 10; // 10进制进位基数
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
{
return static_cast<uint64_t>(templateId) + kBase * RecursiveSum(templateIds...);
}
// TilingKey 的生成规则:
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key包含以下关键参数从低位到高位依次是Ub0, Ub1,
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
// 表示Ub核内切分的轴使用枚举AxisEnum表示因为我们允许最多切分两根轴所以存在UB0和UB1如果没有UB核内切分
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
// Block: 表示UB用来分核的轴使用枚举AxisEnum表示占一个十进制位;
// DataType: 表示当前tiling key支持的输入输出的数据类型使用枚举SupportedDtype来表示占一个十进制位
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示占一个十进制位
// Sparse: 表示当前tiling key是否支持Sparse使用枚举SparseCapability表示占一个十进制位
// 其余特化场景,定义自己的位域和值
// usage: get tilingKey from inputed types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from inputed types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
SparseEnum::sparse))
} // namespace Optiling
} // namespace Transformer
} // namespace Ops

View File

@@ -1,351 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_templates_registry.h
* \brief
*/
#pragma once
#include <map>
#include <string>
#include <memory>
#include "exe_graph/runtime/tiling_context.h"
#include "tiling_base.h"
#include "error_log.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
template <typename T>
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
{
return std::unique_ptr<T>(new (std::nothrow) T(context));
}
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
class TilingCases {
public:
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
{}
template <typename T>
void AddTiling(int32_t priority)
{
OP_CHECK_IF(
cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
cases_[priority] = TILING_CLASS<T>;
OP_CHECK_IF(
cases_[priority] == nullptr,
OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
}
const std::map<int32_t, TilingClassCase>& GetTilingCases()
{
return cases_;
}
private:
std::map<int32_t, TilingClassCase> cases_;
const std::string op_type_;
};
// --------------------------------Interfacce with soc version --------------------------------
class TilingRegistryNew {
public:
TilingRegistryNew() = default;
#ifdef ASCENDC_OP_TEST
static TilingRegistryNew& GetInstance();
#else
static TilingRegistryNew& GetInstance()
{
static TilingRegistryNew registry_impl_;
return registry_impl_;
}
#endif
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
{
auto soc_iter = registry_map_.find(soc_version);
if (soc_iter == registry_map_.end()) {
std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
registry_map_[soc_version] = op_type_map;
} else {
if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
}
}
OP_CHECK_IF(
registry_map_[soc_version][op_type] == nullptr,
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
return registry_map_[soc_version][op_type];
}
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
{
int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
const char* op_type = context->GetNodeType();
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
if (platformInfoPtr == nullptr) {
auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
OP_CHECK_IF(
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
soc_version = compileInfoPtr->socVersion;
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
} else {
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
OP_LOGD(context, "soc version is %d", soc_version);
if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
return ge::GRAPH_FAILED;
}
}
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
auto tilingTemplate = it->second(context);
if (tilingTemplate != nullptr) {
ge::graphStatus status = tilingTemplate->DoTiling();
if (status != ge::GRAPH_PARAM_INVALID) {
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
return status;
}
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
}
}
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
return ge::GRAPH_FAILED;
}
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
{
int32_t soc_version;
const char* op_type = context->GetNodeType();
auto platformInfoPtr = context->GetPlatformInfo();
if (platformInfoPtr == nullptr) {
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
OP_CHECK_IF(
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
soc_version = compileInfoPtr->socVersion;
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
} else {
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
OP_LOGD(context, "soc version is %d", soc_version);
}
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
for (auto priority_id : priorities) {
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
auto templateFunc = tilingCaseIter->second(context);
if (templateFunc != nullptr) {
ge::graphStatus status = templateFunc->DoTiling();
if (status == ge::GRAPH_SUCCESS) {
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
return status;
}
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
}
}
}
return ge::GRAPH_FAILED;
}
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
{
auto soc_iter = registry_map_.find(soc_version);
OP_CHECK_IF(
soc_iter == registry_map_.end(),
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
return empty_tiling_case_);
auto op_iter = soc_iter->second.find(op_type);
OP_CHECK_IF(
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
return empty_tiling_case_);
return op_iter->second->GetTilingCases();
}
private:
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
};
class RegisterNew {
public:
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
{}
template <typename T>
RegisterNew& tiling(int32_t priority, int32_t soc_version)
{
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
OP_CHECK_IF(
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
tilingCases->AddTiling<T>(priority);
return *this;
}
template <typename T>
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
{
for (int32_t soc_version : soc_versions) {
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
OP_CHECK_IF(
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
return *this);
tilingCases->AddTiling<T>(priority);
}
return *this;
}
private:
const std::string op_type_;
};
// --------------------------------Interfacce without soc version --------------------------------
class TilingRegistry {
public:
TilingRegistry() = default;
#ifdef ASCENDC_OP_TEST
static TilingRegistry& GetInstance();
#else
static TilingRegistry& GetInstance()
{
static TilingRegistry registry_impl_;
return registry_impl_;
}
#endif
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
{
if (registry_map_.find(op_type) == registry_map_.end()) {
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
}
OP_CHECK_IF(
registry_map_[op_type] == nullptr,
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
return registry_map_[op_type];
}
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
{
const char* op_type = context->GetNodeType();
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
auto tilingTemplate = it->second(context);
if (tilingTemplate != nullptr) {
ge::graphStatus status = tilingTemplate->DoTiling();
if (status != ge::GRAPH_PARAM_INVALID) {
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
return status;
}
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
}
}
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
return ge::GRAPH_FAILED;
}
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
{
const char* op_type = context->GetNodeType();
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
for (auto priorityId : priorities) {
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
if (templateFunc != nullptr) {
ge::graphStatus status = templateFunc->DoTiling();
if (status == ge::GRAPH_SUCCESS) {
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
return status;
}
if (status != ge::GRAPH_PARAM_INVALID) {
OP_LOGD(context, "Do op tiling failed");
return status;
}
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
}
}
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
return ge::GRAPH_FAILED;
}
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
{
OP_CHECK_IF(
registry_map_.find(op_type) == registry_map_.end(),
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
return registry_map_[op_type]->GetTilingCases();
}
private:
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
};
class Register {
public:
explicit Register(std::string op_type) : op_type_(std::move(op_type))
{}
template <typename T>
Register& tiling(int32_t priority)
{
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
OP_CHECK_IF(
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
tilingCases->AddTiling<T>(priority);
return *this;
}
private:
const std::string op_type_;
};
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops
// op_type: 算子名称, class_name: 注册的 tiling 类, soc_version芯片版本号
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
// op_type: 算子名称, class_name: 注册的 tiling 类,
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
// op_type: 算子名称, class_name: 注册的 tiling 类,
// soc_version: soc版本用于区分不同的soc
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
// op_type: 算子名称, class_name: 注册的 tiling 类,
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
// 取代 REGISTER_TILING_TEMPLATE , 传入的op_type如果是字符串常量需要去掉引号
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
static Ops::Transformer::OpTiling::Register \
__attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)

View File

@@ -1,139 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_type.h
* \brief
*/
#pragma once
#include <cstdint>
namespace optiling {
enum class AxisEnum {
B = 0,
N2 = 1,
G = 2,
S1 = 3,
S2 = 4,
D = 5,
NONE = 9,
};
enum class DtypeEnum {
FLOAT16 = 0,
FLOAT32 = 1,
BFLOAT16 = 2,
FLOAT16_PRECISION = 3,
};
enum class PerformanceOrientedEnum {
BIG_BUFFER = 1,
BIG_DOUBLE_BUFFER = 2,
};
enum class MatmulConfig {
NULL_CONFIG = 0,
NORMAL_CONFIG = 1,
MDL_CONFIG = 2
};
enum class PseConfig {
NO_PSE = 0,
EXIST_PSE = 1
};
enum class AttenMaskConfig {
NO_ATTEN_MASK = 0,
EXIST_ATTEN_MASK = 1
};
enum class DropOutConfig {
NO_DROP_OUT = 0,
EXIST_DROP_OUT = 1
};
enum class CubeFormatEnum {
ND = 0,
NZ = 1
};
enum class LayoutEnum {
BSND = 0,
SBND = 1,
BNSD = 2,
TND = 3,
NTD_TND = 4
};
enum class CubeInputSourceEnum {
GM = 0,
L1 = 1
};
enum class OptionEnum {
DISABLE = 0,
ENABLE = 1
};
enum class SparseEnum {
ALL = 0,
NONE = 1,
ANY = 2,
CAUSAL = 3,
BAND = 4,
PREFIX = 5,
BAND_COMPRESS = 6,
RIGHT_DOWN_CAUSAL = 7,
RIGHT_DOWN_CAUSAL_BAND = 8,
BAND_LEFT_UP_CAUSAL = 9
};
constexpr uint64_t RecursiveSum()
{
return 0;
}
constexpr int64_t base10Multiplier = 10;
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
{
return static_cast<uint64_t>(templateId) + base10Multiplier * RecursiveSum(templateIds...);
}
// TilingKey 的生成规则:
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key包含以下关键参数从低位到高位依次是Ub0, Ub1,
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
// 表示Ub核内切分的轴使用枚举AxisEnum表示因为我们允许最多切分两根轴所以存在UB0和UB1如果没有UB核内切分
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
// Block: 表示UB用来分核的轴使用枚举AxisEnum表示占一个十进制位;
// DataType: 表示当前tiling key支持的输入输出的数据类型使用枚举SupportedDtype来表示占一个十进制位
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示占一个十进制位
// Sparse: 表示当前tiling key是否支持Sparse使用枚举SparseCapability表示占一个十进制位
// 其余特化场景,定义自己的位域和值
// usage: get tilingKey from inputed types
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
}
// usage: get tilingKey from inputed types
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
SparseEnum::sparse))
} // namespace optiling

View File

@@ -1,30 +0,0 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file tiling_util.h
* \brief
*/
#pragma once
#include "register/op_impl_registry.h"
namespace Ops {
namespace Transformer {
namespace OpTiling {
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
bool IsRegbaseSocVersion(const gert::TilingContext* context);
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops

View File

@@ -1118,60 +1118,6 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons
return combined_x;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
const at::Tensor& x,
int64_t k,
int64_t kGroup,
int64_t groupCount,
int64_t groupSelectMode,
int64_t renorm,
int64_t normType,
bool outFlag,
double routedScalingFactor,
double eps,
const c10::optional<at::Tensor>& biasOptional
)
{
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
TORCH_CHECK(
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
x.scalar_type());
auto x_size = x.sizes();
auto rows = x_size[0];
auto expert_num = x_size[1];
const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); });
if (bias.defined()) {
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
auto bias_size = bias.sizes();
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
}
at::Tensor yOut = at::empty({rows, k}, x.options());
at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt));
at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
EXEC_NPU_CMD(aclnnMoeGatingTopK,
x, // input_x
biasOptional,
k, // k
kGroup, // k_group
groupCount, // group_count
groupSelectMode, // group_select_mode
renorm, // renorm
normType, // norm_type
outFlag, // out_flag
routedScalingFactor, // routed_scaling_factor
eps, // eps
yOut, // input_y (注意:这里应该是 yOut)
expertIdxOut, // output
outOut
);
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(outOut,expertIdxOut, yOut);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
const at::Tensor &x, const at::Tensor &expert_idx,
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
@@ -1275,25 +1221,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
{
// vLLM-Ascend custom ops
ops.def(
"moe_gating_top_k(Tensor x, "
"int k, "
"int kGroup, "
"int groupCount, "
"int groupSelectMode, "
"int renorm, "
"int normType, "
"bool outFlag, "
"float routedScalingFactor, "
"float eps,"
"Tensor? biasOptional=None)"
"-> (Tensor outOut,Tensor expertIdxOut, Tensor yOut)"
);
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
//Moe_gating
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);

View File

@@ -283,42 +283,6 @@ std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm_meta(
return {output, add_out};
}
std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
const at::Tensor& x,
int64_t k,
int64_t kGroup,
int64_t groupCount,
int64_t groupSelectMode,
int64_t renorm,
int64_t normType,
bool outFlag,
double routedScalingFactor,
double eps,
const c10::optional<at::Tensor>& biasOptional)
{
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
TORCH_CHECK(
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
x.scalar_type());
auto x_size = x.sizes();
auto rows = x_size[0];
auto expert_num = x_size[1];
const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); });
if (bias.defined()) {
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
auto bias_size = bias.sizes();
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
}
at::Tensor yOut = at::empty({rows, k}, x.options());
at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt));
at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(outOut,expertIdxOut, yOut);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom_meta(
const at::Tensor &x, const at::Tensor &expert_idx,
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
@@ -403,15 +367,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
}
} // namespace meta
} // namespace vllm_ascend
namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
// Moe_gating_top_k
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation

View File

@@ -179,8 +179,7 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=2,
global_bs=self.batch_size * self.ep_world_size,
expert_token_nums_type=
1, # 0 represents prefix sum, 1 represents individual counts
expert_token_nums_type=1, # 0代表前缀和1代表各自数量
)
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
output_dtype = x.dtype
@@ -189,8 +188,8 @@ class SmallOps(DecodeMoeOps):
x=[expand_x],
weight=[self.gmm1_weight],
split_item=3,
group_list_type=1, # Default is 0, represents prefix sum format
group_type=0, # 0 represents m-axis grouping
group_list_type=1, # 默认为0代表前缀和形式
group_type=0, # 0代表m轴分组
group_list=expert_token_nums,
output_dtype=torch.int32)[0]
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
@@ -366,7 +365,7 @@ def run_once(local_rank_id,
with_mc2_mask=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id # Single machine
global_rank_id = local_rank_id # 单机
device_id = local_rank_id % 16
torch_npu.npu.set_device(device_id)

View File

@@ -1,322 +0,0 @@
import itertools
import logging
import random
from typing import Optional, Tuple
import numpy as np
import torch
from torch_npu.testing.testcase import TestCase, run_tests
try:
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
except ImportError:
logging.warning(
"vllm_ascend.utils.enable_custom_op not found, skip custom op initialization"
)
def enable_custom_op() -> None:
pass
# Set random seed for reproducibility
SEED = 45
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if hasattr(torch, "npu") and torch.npu.is_available():
torch.npu.manual_seed_all(SEED)
# Configure logging
logging.basicConfig(level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
def softmax_func(
x: np.ndarray,
axis: Optional[int] = None,
eps: float = 1e-20) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Stable softmax implementation for MOE gating.
Args:
x: Input array
axis: Axis to compute softmax
eps: Epsilon to avoid division by zero
Returns:
softmax_output: Softmax result
x_max: Max value for numerical stability
x_sum: Sum of exponentials
"""
if "float16" in x.dtype.name:
x = x.astype(np.float32)
x_max = x.max(axis=axis, keepdims=True)
x_sub = x - x_max
y = np.exp(x_sub)
x_sum = y.sum(axis=axis, keepdims=True)
softmax_output = y / (x_sum + eps)
return softmax_output, x_max, x_sum
class TestNpuMoeGatingTopK(TestCase):
"""Test suite for NPU MOE Gating Top-K operator compatibility."""
def moe_gating_top_k_np(
self,
x: np.ndarray,
k: int,
bias: Optional[np.ndarray] = None,
k_group: int = 1,
group_count: int = 1,
group_select_mode: int = 0,
renorm: int = 0,
norm_type: int = 0,
y2_flag: bool = False,
routed_scaling_factor: float = 1.0,
eps: float = 1e-20
) -> Tuple[torch.Tensor, np.ndarray, Optional[np.ndarray]]:
"""
NumPy reference implementation of MOE gating Top-K logic.
Args:
x: Input features, shape [batch_size, num_experts]
k: Number of experts to select per sample
bias: Gating bias, shape [num_experts]
k_group: Number of groups to select (group mode)
group_count: Number of expert groups
group_select_mode: 0 (max per group), 1 (sum of top-2 per group)
renorm: Whether to renormalize weights (1=enable, 0=disable)
norm_type: 0 (softmax), 1 (sigmoid)
y2_flag: Whether to return original x as y2
routed_scaling_factor: Weight scaling factor
eps: Epsilon for numerical stability
Returns:
y: Selected expert weights (Tensor)
indices: Selected expert indices (int32 numpy array)
y2: Original x if y2_flag=True, else None
"""
# Convert torch tensors to numpy arrays if needed (compatibility layer)
if isinstance(x, torch.Tensor):
x = x.cpu().numpy()
if isinstance(bias, torch.Tensor):
bias = bias.cpu().numpy()
# Type conversion for numerical stability
orig_dtype = x.dtype
if orig_dtype != np.float32:
x = x.astype(np.float32)
if bias is not None:
bias = bias.astype(np.float32)
# Apply normalization (softmax/sigmoid)
if norm_type == 0:
x, _, _ = softmax_func(x, axis=-1, eps=eps)
else:
x = 1 / (1 + np.exp(-x)) # Sigmoid
original_x = x.copy()
# Apply bias if provided
if bias is not None:
x = x + bias
# Group-based expert selection
if group_count > 1:
batch_size, num_experts = x.shape
if num_experts % group_count != 0:
raise ValueError(
f"num_experts ({num_experts}) must be divisible by group_count ({group_count})"
)
group_size = num_experts // group_count
# Reshape to [batch, groups, group_size]
x_reshaped = x.reshape(batch_size, group_count, group_size)
# Compute group scores
if group_select_mode == 0:
group_scores = np.amax(x_reshaped, axis=-1)
else:
# Sum of top-2 values per group
group_scores = np.partition(x_reshaped, -2,
axis=-1)[..., -2:].sum(axis=-1)
# Select top-k_group groups
top_groups = np.argsort(-group_scores, axis=-1,
kind="stable")[:, :k_group]
# Mask out non-selected groups with -inf
mask = np.ones((batch_size, group_count), dtype=bool)
mask[np.arange(batch_size)[:, None], top_groups] = False
x_reshaped = np.where(mask[..., None], float("-inf"), x_reshaped)
# Reshape back to original
x = x_reshaped.reshape(batch_size, num_experts)
# Select top-k experts
x_tensor = torch.from_numpy(x)
_, topk_indices = torch.sort(x_tensor,
dim=-1,
stable=True,
descending=True)
topk_indices = np.asarray(topk_indices[:, :k], dtype=np.int32)
# Extract weights for selected experts
selected_weights = np.take_along_axis(original_x, topk_indices, axis=1)
# Apply renormalization if needed
if norm_type == 1 or renorm == 1:
weight_sum = np.sum(selected_weights, axis=-1, keepdims=True)
selected_weights = selected_weights / (weight_sum + eps)
# Apply scaling factor
selected_weights *= routed_scaling_factor
# Prepare y2 output
y2 = original_x if y2_flag else None
# Convert back to torch tensor with original dtype
selected_weights_tensor = torch.tensor(selected_weights,
dtype=orig_dtype)
return selected_weights_tensor, topk_indices, y2
def test_npu_moe_gating_topk_multi(self) -> None:
"""
Multi-case test for NPU MOE Gating Top-K operator.
Validates compatibility with different input shapes and parameter combinations.
"""
# Test parameter space (aligned with vllm-ascend use cases)
test_configs = {
"group_select_modes": [0, 1],
"renorms": [1],
"norm_types": [0, 1],
"group_counts": [1, 8],
"k_ranges": [4, 8, 12, 16, 6, 32],
"x_dim0": range(1, 17), # Batch size 1-16
"x_dim1": [256, 128, 64, 208, 192, 160] # Expert counts
}
# Generate parameter combinations
param_combinations = itertools.product(
test_configs["group_select_modes"], test_configs["renorms"],
test_configs["norm_types"], test_configs["group_counts"],
test_configs["k_ranges"], test_configs["x_dim0"],
test_configs["x_dim1"])
# Limit test cases to avoid excessive runtime (adjust as needed)
max_test_cases = 100
tested_cases = 0
for params in param_combinations:
if tested_cases >= max_test_cases:
break
(group_select_mode, renorm, norm_type, group_count, k, dim0,
dim1) = params
# Skip invalid configurations
if group_count > 1:
if dim1 % group_count != 0:
continue
if k > (dim1 // group_count):
continue
# Generate random inputs (consistent with vllm-ascend input distribution)
x_np = np.random.uniform(-2.0, 2.0,
(dim0, dim1)).astype(np.float32)
bias_np = np.random.uniform(-2.0, 2.0, (dim1, )).astype(np.float32)
# Convert to torch tensors
x_tensor = torch.tensor(x_np, dtype=torch.float32)
bias_tensor = torch.tensor(bias_np, dtype=torch.float32)
# Random k_group (within valid range)
k_group = random.randint(1, min(group_count, 4))
# Fixed parameters (aligned with NPU OP defaults)
y2_flag = False
routed_scaling_factor = 1.0
eps = 1e-20
try:
# Get NumPy reference result
ref_weights, ref_indices, ref_y2 = self.moe_gating_top_k_np(
x=x_tensor,
k=k,
bias=bias_tensor,
k_group=k_group,
group_count=group_count,
group_select_mode=group_select_mode,
renorm=renorm,
norm_type=norm_type,
y2_flag=y2_flag,
routed_scaling_factor=routed_scaling_factor,
eps=eps)
# Skip if NPU OP is not available
if not hasattr(torch.ops, "_C_ascend") or not hasattr(
torch.ops._C_ascend, "moe_gating_top_k"):
logger.warning(
"NPU MOE gating OP not found, skipping NPU test")
continue
# Get NPU OP result
npu_weights, npu_indices, npu_y2 = torch.ops._C_ascend.moe_gating_top_k(
x=x_tensor.npu(),
k=k,
kGroup=k_group,
groupCount=group_count,
groupSelectMode=group_select_mode,
renorm=renorm,
normType=norm_type,
outFlag=y2_flag,
routedScalingFactor=routed_scaling_factor,
eps=eps,
biasOptional=bias_tensor.npu()
if bias_tensor is not None else None)
# Convert NPU results to CPU for comparison
npu_weights_cpu = npu_weights.cpu()
npu_indices_cpu = npu_indices.cpu().numpy()
# Log test case info (vllm-ascend standard format)
logger.info(
f"Test Case {tested_cases + 1}: "
f"x_shape=({dim0},{dim1}), k={k}, group_count={group_count}, "
f"select_mode={group_select_mode}, norm_type={norm_type}, renorm={renorm}"
)
# Validate results (RTOL=1e-3 is standard for NPU numerical tolerance)
self.assertRtolEqual(ref_weights,
npu_weights_cpu,
rtol=1e-3,
atol=1e-5)
self.assertRtolEqual(ref_indices, npu_indices_cpu)
# Validate y2 if enabled
if y2_flag:
self.assertRtolEqual(ref_y2,
npu_y2.cpu().numpy(),
rtol=1e-3,
atol=1e-5)
tested_cases += 1
logger.info(f"Test Case {tested_cases} passed ")
except Exception as e:
logger.error(f"Test Case failed with error: {str(e)}",
exc_info=True)
continue
logger.info(f"Completed {tested_cases}/{max_test_cases} test cases")
if __name__ == "__main__":
# Run tests with vllm-ascend standard verbosity
run_tests(verbosity=2)

View File

@@ -311,7 +311,7 @@ def test_client_handler_mismatch(server_config):
mismatch_data = {
"label": "JOIN",
"content": {
"device_id": 1, # Mismatched ID
"device_id": 1, # 不匹配的ID
"model_path": "/wrong/model",
"tp": 2,
"pp": 2,

View File

@@ -670,7 +670,7 @@ class TestNPUWorker(TestBase):
(5000, 10000),
]
# Create worker mock
# 创建 worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.init_npu_memory = 8500

View File

@@ -17,6 +17,7 @@
from typing import Callable, Optional
import torch
import torch_npu
from vllm_ascend.utils import get_weight_prefetch_method
@@ -213,19 +214,21 @@ def _select_experts_with_fusion_ops(
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
_, topk_ids, topk_weights = torch.ops._C_ascend.moe_gating_top_k(
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
kGroup=topk_group,
groupCount=num_expert_group,
groupSelectMode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=1, # 0: softmax->topk(fix); 1: topk->softmax
normType=norm_type, # 0: softmax; 1: sigmoid
outFlag=False, # todo new api; should the third output be output
routedScalingFactor=1,
eps=float(1e-20),
biasOptional=e_score_correction_bias,
)
bias=e_score_correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=norm_type, # 0: softmax; 1: sigmoid
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if scoring_func == "softmax":
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids