[Kernel] Add moe_gating_top_k operator support for Ascend NPU (#5579)
### What this PR does / why we need it?
1.replace moe_gating_top_k from torch_npu with custom op
2.enable the renorm function of moe_gating_top_k in softmax scenerio
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
No need test
- vLLM version: v0.13.0
- vLLM main:
7157596103
---------
Signed-off-by: ZCG12345 <2097562023@qq.com>
This commit is contained in:
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -70,6 +70,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
"dispatch_layout"
|
||||
"notify_dispatch"
|
||||
"moe_init_routing_custom"
|
||||
"moe_gating_top_k"
|
||||
)
|
||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||
SOC_ARG="ascend910_93"
|
||||
|
||||
42
csrc/moe_gating_top_k/op_host/CMakeLists.txt
Normal file
42
csrc/moe_gating_top_k/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,42 @@
|
||||
# ----------------------------------------------------------------------------
|
||||
# This program is free software, you can redistribute it and/or modify.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME MoeGatingTopK
|
||||
OPTIONS --cce-auto-sync=on
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
# Host (aclnn)
|
||||
if (BUILD_OPEN_PROJECT)
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
moe_gating_top_k_def.cpp
|
||||
|
||||
)
|
||||
|
||||
# Tiling
|
||||
target_sources(optiling PRIVATE
|
||||
moe_gating_top_k_tiling.cpp
|
||||
moe_gating_top_k_tiling_base.cpp
|
||||
moe_gating_top_k_tiling_arch35.cpp
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
moe_gating_top_k_proto.cpp
|
||||
moe_gating_top_k_infershape.cpp
|
||||
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
endif()
|
||||
56
csrc/moe_gating_top_k/op_host/error_log.h
Normal file
56
csrc/moe_gating_top_k/op_host/error_log.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
|
||||
#include <string>
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
#define OP_LOGI(opname, ...)
|
||||
#define OP_LOGW(opname, ...) \
|
||||
do { \
|
||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||
do { \
|
||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE(opname, ...) \
|
||||
do { \
|
||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGD(opname, ...)
|
||||
|
||||
namespace optiling {
|
||||
|
||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||
do { \
|
||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
|
||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
61
csrc/moe_gating_top_k/op_host/math_util.h
Normal file
61
csrc/moe_gating_top_k/op_host/math_util.h
Normal file
@@ -0,0 +1,61 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file math_util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef TILING_MATMUL_MATH_UTIL_H
|
||||
#define TILING_MATMUL_MATH_UTIL_H
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
namespace matmul_tiling {
|
||||
class MathUtil {
|
||||
public:
|
||||
static bool IsEqual(float leftValue, float rightValue);
|
||||
template<typename T>
|
||||
static auto CeilDivision(T num1, T num2) -> T
|
||||
{
|
||||
if (num2 == 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
|
||||
static_cast<int64_t>(num2));
|
||||
}
|
||||
template<typename T>
|
||||
static auto Align(T num1, T num2) -> T
|
||||
{
|
||||
return CeilDivision(num1, num2) * num2;
|
||||
}
|
||||
static int32_t AlignDown(int32_t num1, int32_t num2);
|
||||
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
|
||||
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
|
||||
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
|
||||
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||
const int32_t factorEnd);
|
||||
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
||||
const int32_t factorEnd);
|
||||
static bool CheckFactorNumSatisfy(const int32_t dim);
|
||||
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
|
||||
bool isKDim);
|
||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
|
||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
|
||||
const int32_t coreNum, const int32_t maxNum);
|
||||
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
||||
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
|
||||
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
|
||||
};
|
||||
} // namespace matmul_tiling
|
||||
#endif // _MATH_UTIL_H_
|
||||
71
csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp
Normal file
71
csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class MoeGatingTopK : public OpDef {
|
||||
public:
|
||||
explicit MoeGatingTopK(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("bias")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("y")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("expert_idx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Attr("k").Int();
|
||||
this->Attr("k_group").AttrType(OPTIONAL).Int(1);
|
||||
this->Attr("group_count").AttrType(OPTIONAL).Int(1);
|
||||
this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("renorm").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("norm_type").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("out_flag").AttrType(OPTIONAL).Bool(false);
|
||||
this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0);
|
||||
this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f);
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
this->AICore().AddConfig("ascend910_93");
|
||||
|
||||
OpAICoreConfig regbaseCfg;
|
||||
regbaseCfg.DynamicCompileStaticFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt");
|
||||
this->AICore().AddConfig("ascend910_95", regbaseCfg);
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(MoeGatingTopK);
|
||||
} // namespace ops
|
||||
147
csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp
Normal file
147
csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_gating_top_k_infershape.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "exe_graph/runtime/infer_shape_context.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "error_log.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <string>
|
||||
#define TO_STRING(x) std::string(#x)
|
||||
|
||||
using namespace ge;
|
||||
namespace ops {
|
||||
static constexpr size_t DIM_ONE = 1;
|
||||
static constexpr size_t DIM_TWO = 2;
|
||||
static constexpr int64_t NEG_ONE = -1;
|
||||
static constexpr int64_t X_INDEX = 0;
|
||||
static constexpr int64_t BIAS_INDEX = 1;
|
||||
static constexpr int64_t Y_INDEX = 0;
|
||||
static constexpr int64_t EXPERT_IDX_INDEX = 1;
|
||||
static constexpr int64_t OUT_INDEX = 2;
|
||||
|
||||
static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape)
|
||||
{
|
||||
int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
|
||||
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
|
||||
if (XRows < NEG_ONE || expertNum < NEG_ONE) {
|
||||
OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape,
|
||||
const int64_t k)
|
||||
{
|
||||
if (xShape->GetDimNum() == 1U) {
|
||||
if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) {
|
||||
OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.",
|
||||
TO_STRING(*xShape).c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
} else if (xShape->GetDimNum() != DIM_TWO) {
|
||||
OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.",
|
||||
TO_STRING(*xShape).c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if (k < 0) {
|
||||
OP_LOGE(context, "k must be a non-negative number.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k)
|
||||
{
|
||||
OP_LOGD(context, "x shape is: %s.", TO_STRING(*xShape).c_str());
|
||||
OP_LOGD(context, "k is: %ld.", k);
|
||||
}
|
||||
|
||||
static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape,
|
||||
const gert::Shape *expertIdxShape, const gert::Shape *outShape)
|
||||
{
|
||||
OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str());
|
||||
OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str());
|
||||
OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str());
|
||||
}
|
||||
|
||||
static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context)
|
||||
{
|
||||
OP_LOGD(context, "Begin to do MoeGatingTopKInfershape.");
|
||||
|
||||
// 获取输入shape
|
||||
const gert::Shape *xShape = context->GetInputShape(0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
||||
gert::Shape *yShape = context->GetOutputShape(0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
||||
gert::Shape *expertIdxShape = context->GetOutputShape(1);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape);
|
||||
gert::Shape *outShape = context->GetOutputShape(2);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
|
||||
|
||||
// 获取attr
|
||||
auto attrs = context->GetAttrs();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, kPtr);
|
||||
const int64_t k = *kPtr;
|
||||
ShowInputShapeInfo(context, xShape, k);
|
||||
|
||||
// 参数校验
|
||||
if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
|
||||
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
|
||||
|
||||
yShape->SetDimNum(DIM_TWO);
|
||||
yShape->SetDim(0U, rows);
|
||||
yShape->SetDim(1U, k);
|
||||
|
||||
expertIdxShape->SetDimNum(DIM_TWO);
|
||||
expertIdxShape->SetDim(0U, rows);
|
||||
expertIdxShape->SetDim(1U, k);
|
||||
|
||||
outShape->SetDimNum(DIM_TWO);
|
||||
outShape->SetDim(0U, rows);
|
||||
outShape->SetDim(1U, expertNum);
|
||||
|
||||
ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape);
|
||||
OP_LOGD(context, "End to do MoeGatingTopKInfershape.");
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context)
|
||||
{
|
||||
OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType.");
|
||||
auto xDtype = context->GetInputDataType(0);
|
||||
context->SetOutputDataType(Y_INDEX, xDtype);
|
||||
context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32);
|
||||
context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT);
|
||||
OP_LOGD(context, "End to do MoeGatingTopKInferDataType.");
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK);
|
||||
} // namespace ops
|
||||
15
csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp
Normal file
15
csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_proto.h
|
||||
* \brief
|
||||
*/
|
||||
#include "moe_gating_top_k_proto.h"
|
||||
66
csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h
Normal file
66
csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_proto.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
||||
#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
||||
|
||||
#include "graph/operator_reg.h"
|
||||
|
||||
namespace ge {
|
||||
|
||||
/**
|
||||
* @brief Compute renorm(sigmoid) and topk for moe input.
|
||||
*
|
||||
* @par Inputs:
|
||||
* @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count.
|
||||
* @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x.
|
||||
*
|
||||
* @par Outputs:
|
||||
* @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x.
|
||||
The size of the non-1 axis must be the same as that of the corresponding axis of x.
|
||||
The size of the -1 axis must be the same as that of k.
|
||||
* @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y.
|
||||
* @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x.
|
||||
*
|
||||
* @par Attributes:
|
||||
* @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value.
|
||||
* @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1.
|
||||
* @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1.
|
||||
* @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0.
|
||||
* @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0.
|
||||
* @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0.
|
||||
* @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false.
|
||||
* @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0.
|
||||
* @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20.
|
||||
*/
|
||||
REG_OP(MoeGatingTopK)
|
||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
||||
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
||||
.OUTPUT(expert_idx, TensorType({DT_INT32}))
|
||||
.OUTPUT(out, TensorType({DT_FLOAT}))
|
||||
.REQUIRED_ATTR(k, Int)
|
||||
.ATTR(k_group, Int, 1)
|
||||
.ATTR(group_count, Int, 1)
|
||||
.ATTR(group_select_mode, Int, 0)
|
||||
.ATTR(renorm, Int, 0)
|
||||
.ATTR(norm_type, Int, 0)
|
||||
.ATTR(out_flag, Bool, false)
|
||||
.ATTR(routed_scaling_factor, Float, 1.0)
|
||||
.ATTR(eps, Float, 1e-20f)
|
||||
.OP_END_FACTORY_REG(MoeGatingTopK)
|
||||
|
||||
} // namespace ge
|
||||
|
||||
#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
||||
573
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp
Normal file
573
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp
Normal file
@@ -0,0 +1,573 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_gating_top_k_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <cmath>
|
||||
#include "register/op_def_registry.h"
|
||||
#include "exe_graph/runtime/infer_shape_context.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "../tiling_base/tiling_base.h"
|
||||
#include "../tiling_base/tiling_templates_registry.h"
|
||||
#include "platform/platform_info.h"
|
||||
|
||||
|
||||
|
||||
#include "error_log.h"
|
||||
#include "moe_gating_top_k_tiling.h"
|
||||
|
||||
|
||||
#ifndef CEIL_ALIGN
|
||||
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
|
||||
#endif
|
||||
|
||||
#ifndef CEIL_DIV
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
#endif
|
||||
namespace optiling {
|
||||
const static int64_t GROUP_SELECT_MODE_MAX = 0;
|
||||
const static int64_t GROUP_SELECT_MODE_SUM = 1;
|
||||
const static int64_t RENORM_NO = 0;
|
||||
const static int64_t RENORM_L1 = 1;
|
||||
const static int64_t NORM_TYPE_SOFTMAX = 0;
|
||||
const static int64_t NORM_TYPE_SIGMOID = 1;
|
||||
const static int64_t OUT_FLAG_FALSE = 0;
|
||||
const static int64_t OUT_FLAG_TRUE = 1;
|
||||
const static size_t X_INPUT_DIMS = 2;
|
||||
const static size_t BIAS_INPUT_DIMS = 1;
|
||||
const static size_t Y_OUTPUT_DIMS = 2;
|
||||
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
|
||||
const static size_t OUT_OUTPUT_DIMS = 2;
|
||||
const static int64_t MAX_EXPERT_COUNT = 2048;
|
||||
|
||||
const static int64_t X_INPUT_INDEX = 0;
|
||||
const static int64_t BIAS_INPUT_INDEX = 1;
|
||||
const static int64_t Y_OUTPUT_INDEX = 0;
|
||||
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
|
||||
const static int64_t OUT_OUTPUT_INDEX = 2;
|
||||
const static int64_t K_ATTR_INDEX = 0;
|
||||
const static int64_t K_GROUP_ATTR_INDEX = 1;
|
||||
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
|
||||
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
|
||||
const static int64_t RENORM_ATTR_INDEX = 4;
|
||||
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
|
||||
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
|
||||
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
|
||||
const static int64_t EPS_ATTR_INDEX = 8;
|
||||
const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216;
|
||||
const static uint32_t DATATYPESIZE_FLOAT = 4;
|
||||
const static bool IS_LARGEST = true;
|
||||
const static bool IS_INITINDEX = false;
|
||||
const static bool IS_REUSESOURCE = false;
|
||||
const static uint64_t WITH_GROUP_CONDITION = 1;
|
||||
const static uint64_t WITHOUT_GROUP_CONDITION = 2;
|
||||
const static uint64_t MAX_IN_GROUP_CONDITION = 3;
|
||||
constexpr int32_t ROW_COUNT_PER_TASK = 1;
|
||||
|
||||
const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0;
|
||||
const static uint64_t TILING_KEY_WITHOUT_GROUP = 1;
|
||||
const static uint64_t TILING_KEY_GENERALIZED = 2;
|
||||
|
||||
inline static int64_t CeilLog4(int64_t x)
|
||||
{
|
||||
return static_cast<int64_t>(std::ceil(std::log(x) / std::log(4))); // 4 for four
|
||||
}
|
||||
|
||||
class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass {
|
||||
public:
|
||||
explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
|
||||
{
|
||||
Reset();
|
||||
}
|
||||
~MoeGatingTopKTilingBase() override = default;
|
||||
|
||||
void Reset(gert::TilingContext *context) override
|
||||
{
|
||||
TilingBaseClass::Reset(context);
|
||||
Reset();
|
||||
}
|
||||
|
||||
protected:
|
||||
bool IsCapable() override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
ge::graphStatus GetPlatformInfo() override;
|
||||
|
||||
ge::graphStatus GetShapeAttrsInfo() override;
|
||||
|
||||
ge::graphStatus DoOpTiling() override;
|
||||
|
||||
ge::graphStatus DoLibApiTiling() override;
|
||||
|
||||
uint64_t GetTilingKey() const override;
|
||||
|
||||
ge::graphStatus GetWorkspaceSize() override;
|
||||
|
||||
ge::graphStatus PostTiling() override;
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
ge::graphStatus CheckInputShape();
|
||||
ge::graphStatus CheckAttr();
|
||||
ge::graphStatus CheckOutShape();
|
||||
void SplitRows();
|
||||
void CalTmpBufUbSize();
|
||||
|
||||
const gert::Shape *xShape_ = nullptr;
|
||||
const gert::Shape *biasShape_ = nullptr;
|
||||
const gert::Shape *yShape_ = nullptr;
|
||||
const gert::Shape *expertIdxShape_ = nullptr;
|
||||
const gert::Shape *outShape_ = nullptr;
|
||||
|
||||
int64_t rows_ = 0;
|
||||
int64_t expertCount_ = 0;
|
||||
int64_t addBias_ = 0;
|
||||
|
||||
int64_t k_ = 0;
|
||||
int64_t kGroup_ = 0;
|
||||
int64_t groupCount_ = 0;
|
||||
int64_t perGroupExpertCount_ = 0;
|
||||
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
|
||||
int64_t renorm_ = RENORM_NO;
|
||||
int64_t normType_ = NORM_TYPE_SOFTMAX;
|
||||
int64_t outFlag_ = OUT_FLAG_FALSE;
|
||||
float routedScalingFactor_ = 1.0;
|
||||
float eps_ = 1e-20f;
|
||||
|
||||
int64_t inputDtypeSize_;
|
||||
const char *opName_ = "";
|
||||
MoeGatingTopKTilingData moeGatingTopKTilingData_;
|
||||
};
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape()
|
||||
{
|
||||
size_t xDimNum = xShape_->GetDimNum();
|
||||
|
||||
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
|
||||
|
||||
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
|
||||
rows_ = xShape_->GetDim(0);
|
||||
expertCount_ = xShape_->GetDim(1);
|
||||
|
||||
moeGatingTopKTilingData_.set_rowCount(rows_);
|
||||
moeGatingTopKTilingData_.set_expertCount(expertCount_);
|
||||
if (biasShape_ != nullptr) {
|
||||
addBias_ = 1;
|
||||
size_t biasDimNum = biasShape_->GetDimNum();
|
||||
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
|
||||
OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_CHECK_IF(
|
||||
biasShape_->GetDim(0) != expertCount_,
|
||||
OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
}
|
||||
moeGatingTopKTilingData_.set_addBias(addBias_);
|
||||
|
||||
OP_CHECK_IF(k_ > expertCount_,
|
||||
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::CheckAttr()
|
||||
{
|
||||
OP_CHECK_IF(
|
||||
expertCount_ > MAX_EXPERT_COUNT,
|
||||
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(kGroup_ > groupCount_,
|
||||
OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
|
||||
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
|
||||
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
|
||||
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
|
||||
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
|
||||
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
|
||||
OP_LOGE(context_, "Expert count : %ld is not divisible by k_group: %ld", expertCount_, groupCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
perGroupExpertCount_ = expertCount_ / groupCount_;
|
||||
|
||||
OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
|
||||
|
||||
OP_CHECK_IF(perGroupExpertCount_ < 1,
|
||||
OP_LOGE(context_, "group expert count is: %ld, but should be greater than 1.", perGroupExpertCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_CHECK_IF(
|
||||
groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
|
||||
OP_LOGE(context_,
|
||||
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
|
||||
perGroupExpertCount_, groupSelectMode_),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
|
||||
OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_),
|
||||
return ge::GRAPH_FAILED);
|
||||
int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L);
|
||||
OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign);
|
||||
if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) {
|
||||
|
||||
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
|
||||
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
|
||||
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
|
||||
moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_);
|
||||
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo()
{
    // Collects and validates the input/output shapes and dtypes, then parses
    // every op attribute, mirroring each value into the tiling data struct.
    // Returns GRAPH_FAILED on any dtype mismatch or missing mandatory input.
    opName_ = context_->GetNodeName();
    OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_);

    // --- shapes ---
    auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
    xShape_ = &xShapePtr->GetStorageShape();
    OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str());

    // bias is an optional input; biasShape_ stays nullptr when it is absent.
    auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
    biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
    if (biasShape_ != nullptr) {
        OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str());
    }

    auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
    yShape_ = &yShapePtr->GetStorageShape();
    OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str());
    auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
    expertIdxShape_ = &expertIdxPtr->GetStorageShape();
    OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str());
    auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr);
    outShape_ = &outPtr->GetStorageShape();
    // outPtr was null-checked above, so outShape_ (address of a reference) can
    // never be nullptr here; the previous `if (outShape_ != nullptr)` guard was
    // dead code and has been removed.
    OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str());

    // --- dtypes ---
    auto x = context_->GetInputDesc(X_INPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, x);
    auto xDtype = x->GetDataType();
    OP_CHECK_IF(
        (xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
        OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
            ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
        return ge::GRAPH_FAILED);

    if (biasShapePtr != nullptr) {
        // BUGFIX: the desc was dereferenced without a null check before.
        auto biasDesc = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX);
        OP_CHECK_NULL_WITH_CONTEXT(context_, biasDesc);
        auto biasDtype = biasDesc->GetDataType();
        OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str());
        OP_CHECK_IF((biasDtype != xDtype),
            OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
                ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
                ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
            return ge::GRAPH_FAILED);
    }

    auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
    auto yDtype = yDesc->GetDataType();
    OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str());
    OP_CHECK_IF((yDtype != xDtype),
        OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
            ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
            ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
        return ge::GRAPH_FAILED);

    auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
    auto expertIdDtype = expertIdDesc->GetDataType();
    OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str());
    OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
        OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
            ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
        return ge::GRAPH_FAILED);

    auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc);
    auto normOutDtype = normOutDesc->GetDataType();
    OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT),
        OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.",
            ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()),
        return ge::GRAPH_FAILED);

    // --- attributes ---
    auto attrs = context_->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);

    // k is mandatory. (The redundant second "Attr k is" log was removed.)
    const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
    k_ = *kPtr;
    moeGatingTopKTilingData_.set_k(k_);
    OP_LOGI(context_, "Attr k is: %ld ", k_);

    // Optional attrs keep the member defaults when absent. The tiling data is
    // now synced unconditionally (matching the regbase GetShapeAttrsInfo), so
    // a missing attr no longer leaves a stale/zero value in the tiling blob.
    const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
    if (kGroupPtr != nullptr) {
        kGroup_ = *kGroupPtr;
    }
    moeGatingTopKTilingData_.set_kGroup(kGroup_);
    OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);

    const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
    if (groupCountPtr != nullptr) {
        groupCount_ = *groupCountPtr;
    }
    moeGatingTopKTilingData_.set_groupCount(groupCount_);
    OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);

    const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
    if (groupSelectModePtr != nullptr) {
        groupSelectMode_ = *groupSelectModePtr;
    }
    moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
    OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);

    const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
    if (renormPtr != nullptr) {
        renorm_ = *renormPtr;
    }
    moeGatingTopKTilingData_.set_renorm(renorm_);
    OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);

    const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
    if (normTypePtr != nullptr) {
        normType_ = *normTypePtr;
    }
    moeGatingTopKTilingData_.set_normType(normType_);
    OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);

    const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
    if (outFlagPtr != nullptr) {
        outFlag_ = (*outFlagPtr) ? 1 : 0;
    }
    moeGatingTopKTilingData_.set_outFlag(outFlag_);
    OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);

    const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
    if (routedScalingFactorPtr != nullptr) {
        routedScalingFactor_ = *routedScalingFactorPtr;
    }
    moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
    OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);

    const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
    if (epsPtr != nullptr) {
        eps_ = *epsPtr;
    }
    moeGatingTopKTilingData_.set_eps(eps_);
    OP_LOGI(context_, "Attr eps is: %f ", eps_);

    // Reuse the already-fetched x dtype instead of re-querying desc index 0.
    inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(xDtype));
    OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo()
{
    // Query the AIV core count and the UB memory capacity from the platform
    // and cache both in aicoreParams_ for the later tiling steps.
    auto rawPlatform = context_->GetPlatformInfo();
    OP_CHECK_IF(rawPlatform == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);

    auto ascendPlatform = platform_ascendc::PlatformAscendC(rawPlatform);
    aicoreParams_.blockDim = ascendPlatform.GetCoreNumAiv();

    uint64_t ubBytes;
    ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    aicoreParams_.ubSize = ubBytes;

    OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape()
{
    // Validates every output shape against x and k:
    //   - all outputs share x's rank and x's dim[0] (row count);
    //   - y / expertId have k columns; norm out keeps x's expert dimension.
    OP_LOGI(context_, "555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());

    // --- rank checks ---
    OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
            xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
            expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    if (outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
            OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
                outShape_->GetDimNum(), xShape_->GetDimNum()),
            return ge::GRAPH_FAILED);
    }

    // --- dim[0] (row count) checks ---  ("euqal" typos fixed in the messages)
    OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "y out dim[0] %ld not equal x dim[0] %ld, please check.", yShape_->GetDim(0),
            xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "expertId out dim[0] %ld not equal x dim[0] %ld, please check.",
            expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        // BUGFIX: the second format argument used to be outShape_->GetDim(0)
        // again, so the diagnostic printed the same value twice; report the
        // x dimension it was compared against instead.
        OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
            OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
                outShape_->GetDim(0), xShape_->GetDim(0)),
            return ge::GRAPH_FAILED);
    }

    // --- dim[1] checks ---
    OP_CHECK_IF((yShape_->GetDim(1) != k_),
        OP_LOGE(context_, "y dim[1] %ld not equal k %ld, please check.", yShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
        OP_LOGE(context_, "expertId dim[1] %ld not equal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
            OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
                xShape_->GetDim(1)),
            return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
void MoeGatingTopKTilingBase::SplitRows()
|
||||
{
|
||||
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
|
||||
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
|
||||
// perCoreRows cannot be 0
|
||||
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
|
||||
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
|
||||
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
|
||||
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
|
||||
int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L));
|
||||
OP_LOGI(context_, "vms count is: %ld", vmsCount);
|
||||
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
|
||||
}
|
||||
|
||||
void MoeGatingTopKTilingBase::CalTmpBufUbSize()
|
||||
|
||||
{
|
||||
|
||||
std::vector<int64_t> shape_vec = {expertCount_};
|
||||
ge::Shape shape(shape_vec);
|
||||
uint32_t maxValue = 0;
|
||||
uint32_t minValue = 0;
|
||||
AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue);
|
||||
|
||||
int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast<int64_t>(sizeof(float));
|
||||
moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast<int64_t>(minValue)));
|
||||
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling()
{
    // Run the validation pipeline (input shape, output shapes, attributes)
    // and, if everything passes, compute the scratch size and the row split.
    OP_LOGI(context_, "DoOpTiling: start");

    ge::graphStatus status = CheckInputShape();
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckOutShape();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckAttr();
    }
    if (status != ge::GRAPH_SUCCESS) {
        return status;
    }

    CalTmpBufUbSize();
    SplitRows();
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling()
{
    // No high-level (library) API tiling is needed for this template; the
    // override exists only to satisfy the TilingBaseClass interface.
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize()
{
    // Fixed-size workspace; no input-dependent scratch memory is requested.
    workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingBase::PostTiling()
{
    // Publish the tiling results back into the context: launch block dim,
    // workspace size, and the serialized tiling data blob.
    context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
    // NOTE(review): assumes GetWorkspaceSizes(1) never returns nullptr here —
    // confirm against the gert contract.
    size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
    currentWorkspace[0] = workspaceSize_;
    moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
        context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
uint64_t MoeGatingTopKTilingBase::GetTilingKey() const
|
||||
{
|
||||
|
||||
if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ &&
|
||||
groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID &&
|
||||
!outFlag_) {
|
||||
|
||||
return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF;
|
||||
} else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) {
|
||||
return TILING_KEY_WITHOUT_GROUP;
|
||||
} else {
|
||||
return TILING_KEY_GENERALIZED;
|
||||
}
|
||||
}
|
||||
|
||||
void MoeGatingTopKTilingBase::Reset()
|
||||
{
|
||||
opName_ = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
// Register this tiling template for the MoeGatingTopK op with priority 2000
// (ordering semantics per REGISTER_OPS_TILING_TEMPLATE; the arch35 regbase
// template is registered separately).
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000);
} // namespace optiling
|
||||
86
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h
Normal file
86
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
||||
#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
#include "../tiling_base/tiling_base.h"
|
||||
#include "../tiling_base/tiling_templates_registry.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "error_log.h"
|
||||
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "platform/platform_infos_def.h"
|
||||
#include "math_util.h"
|
||||
//#include "util/extern_math_util.h"
|
||||
|
||||
namespace optiling {
|
||||
// Tiling data passed to the MoeGatingTopK kernel (default / non-regbase path).
// Field meanings are taken from how the tiling code fills them.
BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);              // number of AIV cores actually used
TILING_DATA_FIELD_DEF(int64_t, rowCount);                 // x dim[0]: total rows
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);          // rows handled by each full core
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);         // rows handled by the last core
TILING_DATA_FIELD_DEF(int64_t, expertCount);              // x dim[1]: number of experts
TILING_DATA_FIELD_DEF(int64_t, addBias);                  // 1 when the optional bias input is present
TILING_DATA_FIELD_DEF(int64_t, k);                        // top-k experts to select
TILING_DATA_FIELD_DEF(int64_t, kGroup);                   // number of groups to keep
TILING_DATA_FIELD_DEF(int64_t, groupCount);               // number of expert groups
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);      // expertCount / groupCount
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); // per-group expert count aligned up to 32
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);          // 0 = max, 1 = sum (see attr constants)
TILING_DATA_FIELD_DEF(int64_t, renorm);                   // renorm attr (0 = no renorm)
TILING_DATA_FIELD_DEF(int64_t, normType);                 // 0 = softmax, 1 = sigmoid
TILING_DATA_FIELD_DEF(int64_t, outFlag);                  // 1 when the norm output should be produced
TILING_DATA_FIELD_DEF(int64_t, vmsCount);                 // vms rounds (CeilLog4(ceil(kGroup/4)))
TILING_DATA_FIELD_DEF(float, routedScalingFactor);        // routed_scaling_factor attr
TILING_DATA_FIELD_DEF(float, eps);                        // eps attr
TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize);          // scratch UB bytes (see CalTmpBufUbSize)
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData)
|
||||
|
||||
// Tiling data for the regbase (ASCEND910_95 / arch35) kernel variant, bound to
// tiling key 10000. Mirrors MoeGatingTopKTilingData but carries a SoftMaxTiling
// instead of the scratch-buffer size field.
BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData)
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);              // number of AIV cores actually used
TILING_DATA_FIELD_DEF(int64_t, rowCount);                 // x dim[0]: total rows
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);          // rows handled by each full core
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);         // rows handled by the last core
TILING_DATA_FIELD_DEF(int64_t, expertCount);              // x dim[1]: number of experts
TILING_DATA_FIELD_DEF(int64_t, addBias);                  // 1 when the optional bias input is present
TILING_DATA_FIELD_DEF(int64_t, k);                        // top-k experts to select
TILING_DATA_FIELD_DEF(int64_t, kGroup);                   // number of groups to keep
TILING_DATA_FIELD_DEF(int64_t, groupCount);               // number of expert groups
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);      // expertCount / groupCount
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); // per-group expert count aligned up to 32
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);          // 0 = max, 1 = sum (see attr constants)
TILING_DATA_FIELD_DEF(int64_t, renorm);                   // renorm attr (0 = no renorm)
TILING_DATA_FIELD_DEF(int64_t, normType);                 // 0 = softmax, 1 = sigmoid
TILING_DATA_FIELD_DEF(int64_t, outFlag);                  // 1 when the norm output should be produced
TILING_DATA_FIELD_DEF(int64_t, vmsCount);                 // vms rounds
TILING_DATA_FIELD_DEF(float, routedScalingFactor);        // routed_scaling_factor attr
TILING_DATA_FIELD_DEF(float, eps);                        // eps attr
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData); // softmax high-level API tiling
END_TILING_DATA_DEF;
// "MoeGatingTopK_10000": bound to MOE_GATING_TOP_K_REGBASE_TILING_KEY (10000).
REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData)
// No compile-time information is required by this op's tiling.
struct MoeGatingTopKCompileInfo {};
|
||||
} // namespace optiling
|
||||
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
||||
521
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp
Normal file
521
csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp
Normal file
@@ -0,0 +1,521 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_gating_top_k_tiling_arch35.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "error_log.h"
|
||||
#include "moe_gating_top_k_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "platform/platform_info.h"
|
||||
#include "../tiling_base/tiling_base.h"
|
||||
#include "../tiling_base/tiling_templates_registry.h"
|
||||
|
||||
// Round val up to the nearest multiple of align (integer arithmetic, align > 0).
#ifndef CEIL_ALIGN
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
#endif
// Integer ceiling division: smallest integer >= a / b (b > 0).
#ifndef CEIL_DIV
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
#endif
|
||||
namespace optiling {
// Tiling key used by the regbase (ASCEND910_95) template.
const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000;

// group_select_mode attr values: score a group by its max vs. by a sum.
const static int64_t GROUP_SELECT_MODE_MAX = 0;
const static int64_t GROUP_SELECT_MODE_SUM = 1;
// renorm attr values (only RENORM_NO is accepted by CheckAttr).
const static int64_t RENORM_NO = 0;
const static int64_t RENORM_L1 = 1;
// norm_type attr values.
const static int64_t NORM_TYPE_SOFTMAX = 0;
const static int64_t NORM_TYPE_SIGMOID = 1;
// out_flag attr values (whether the norm output is produced).
const static int64_t OUT_FLAG_FALSE = 0;
const static int64_t OUT_FLAG_TRUE = 1;
// Expected ranks of the inputs/outputs.
const static size_t X_INPUT_DIMS = 2;
const static size_t BIAS_INPUT_DIMS = 1;
const static size_t Y_OUTPUT_DIMS = 2;
// NOTE(review): "OUTPUY" is a typo in this identifier; kept as-is because
// references may exist outside this file.
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
const static size_t OUT_OUTPUT_DIMS = 2;
// Upper bound on the (aligned) expert count supported by the kernel.
const static int64_t MAX_EXPERT_COUNT = 2048;

// Input/output indices in the op definition.
const static int64_t X_INPUT_INDEX = 0;
const static int64_t BIAS_INPUT_INDEX = 1;
const static int64_t Y_OUTPUT_INDEX = 0;
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
const static int64_t OUT_OUTPUT_INDEX = 2;
// Attribute indices in the op definition.
const static int64_t K_ATTR_INDEX = 0;
const static int64_t K_GROUP_ATTR_INDEX = 1;
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
const static int64_t RENORM_ATTR_INDEX = 4;
// Width of one merge-sort step (presumably 4-way vms — confirm with kernel).
const static int64_t MRGSORT_SIZE = 4;
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
const static int64_t EPS_ATTR_INDEX = 8;
const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast<int64_t>(16 * 1024 * 1024); // Reserve 16 MB of workspace
|
||||
|
||||
|
||||
// Tiling template for MoeGatingTopK on the regbase architecture
// (ASCEND910_95 / arch35). Implements the TilingBaseClass pipeline:
// platform query -> shape/attr collection -> checks -> tiling computation.
class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass {
public:
    explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
    {
        Reset();
    }
    ~MoeGatingTopKTilingRegbase() override = default;

    // Re-bind to a new tiling context and clear per-run state.
    void Reset(gert::TilingContext *context) override
    {
        TilingBaseClass::Reset(context);
        Reset();
    }

protected:
    // This template only handles the ASCEND910_95 SoC.
    bool IsCapable() override
    {
        if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) {
            return false;
        }
        return true;
    }

    ge::graphStatus GetPlatformInfo() override;
    ge::graphStatus GetShapeAttrsInfo() override;
    ge::graphStatus DoOpTiling() override;
    ge::graphStatus DoLibApiTiling() override;
    uint64_t GetTilingKey() const override;
    ge::graphStatus GetWorkspaceSize() override;
    ge::graphStatus PostTiling() override;
    // Parameterless reset of the members below (no context change).
    void Reset();

private:
    ge::graphStatus CheckInputShape();
    ge::graphStatus CheckAttr();
    ge::graphStatus CheckOutShape();
    void CalTmpBufUbSize();
    void SplitRows();
    void Tiling4GatherOutComputeSplitK();

    // Storage shapes captured in GetShapeAttrsInfo (bias/out may stay null).
    const gert::Shape *xShape_ = nullptr;
    const gert::Shape *biasShape_ = nullptr;
    const gert::Shape *yShape_ = nullptr;
    const gert::Shape *expertIdxShape_ = nullptr;
    const gert::Shape *outShape_ = nullptr;

    // NOTE(review): rows_, expertCount_, k_, inputDtypeSize_ and socVersion
    // have no in-class initializer; presumably they are assigned in
    // GetShapeAttrsInfo/CheckInputShape/GetPlatformInfo before use — confirm.
    int64_t rows_;
    int64_t expertCount_;
    int64_t addBias_ = 0; // 1 when the optional bias input is present

    int64_t k_;
    int64_t kGroup_ = 1;
    int64_t groupCount_ = 1;
    int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
    int64_t renorm_ = RENORM_NO;
    int64_t normType_ = NORM_TYPE_SOFTMAX;
    int64_t outFlag_ = OUT_FLAG_FALSE;
    float routedScalingFactor_ = 1.0;
    float eps_ = 1e-20f;

    int64_t inputDtypeSize_; // bytes per element of x's dtype
    const char *opName_ = "";
    MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_;
    platform_ascendc::SocVersion socVersion;
};
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape()
{
    // Validates x (2-D, [rows, expertCount]) and the optional bias (1-D,
    // [expertCount]), records rows_/expertCount_/addBias_ into the tiling
    // data, and rejects k > expertCount.
    size_t xDimNum = xShape_->GetDimNum();
    OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
        OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
        return ge::GRAPH_FAILED);

    rows_ = xShape_->GetDim(0);
    expertCount_ = xShape_->GetDim(1);
    moeGatingTopKTilingData_.set_rowCount(rows_);
    moeGatingTopKTilingData_.set_expertCount(expertCount_);
    OP_CHECK_IF(
        expertCount_ > MAX_EXPERT_COUNT,
        OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
        return ge::GRAPH_FAILED);

    // bias is optional; when present it must be 1-D with one entry per expert.
    if (biasShape_ != nullptr) {
        addBias_ = 1;
        size_t biasDimNum = biasShape_->GetDimNum();
        OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
            OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
            return ge::GRAPH_FAILED);
        OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_,
            OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.",
                biasShape_->GetDim(0), expertCount_),
            return ge::GRAPH_FAILED);
    }
    moeGatingTopKTilingData_.set_addBias(addBias_);

    // k_ was parsed in GetShapeAttrsInfo, which runs before this check.
    OP_CHECK_IF(k_ > expertCount_,
        OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr()
{
    // Validates the attribute combination, normalizes degenerate grouping to
    // a single group, and records the derived per-group counts into the
    // tiling data. Order matters: the degenerate-grouping rewrite below
    // changes kGroup_/groupCount_ before the later checks run.
    OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
    OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(expertCount_ % groupCount_ != 0,
        OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(kGroup_ > groupCount_,
        OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_),
        return ge::GRAPH_FAILED);
    // When every group holds one expert, kGroup_ itself bounds the candidates,
    // so it must cover k. (Checked before the rewrite below resets kGroup_.)
    OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_,
        OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
            kGroup_, k_),
        return ge::GRAPH_FAILED);

    // Degenerate grouping (all groups selected, or one expert per group) is
    // folded into the ungrouped case: a single group containing all experts.
    if (kGroup_ == groupCount_ || groupCount_ == expertCount_) {
        kGroup_ = 1;
        groupCount_ = 1;
    }
    moeGatingTopKTilingData_.set_kGroup(kGroup_);
    moeGatingTopKTilingData_.set_groupCount(groupCount_);
    int64_t groupExpertCount = expertCount_ / groupCount_;
    int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L);
    moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_);
    moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);

    OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
        OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
            groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
        return ge::GRAPH_FAILED);

    // The kept groups must contain at least k candidate experts in total.
    OP_CHECK_IF(kGroup_ * groupExpertCount < k_,
        OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
            kGroup_ * groupExpertCount, k_),
        return ge::GRAPH_FAILED);

    OP_CHECK_IF(groupExpertCount < 1,
        OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
        OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
            GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
        return ge::GRAPH_FAILED);
    // Sum-based group scoring needs at least 2 experts per group.
    OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2,
        OP_LOGE(context_,
            "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
            groupExpertCount, groupSelectMode_),
        return ge::GRAPH_FAILED);

    // Only renorm == 0 is supported by this template.
    OP_CHECK_IF(renorm_ != RENORM_NO,
        OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
        return ge::GRAPH_FAILED);

    OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
        OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
            NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
        return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo()
|
||||
{
|
||||
opName_ = context_->GetNodeName();
|
||||
|
||||
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
|
||||
xShape_ = &xShapePtr->GetStorageShape();
|
||||
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
|
||||
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
|
||||
|
||||
|
||||
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
|
||||
yShape_ = &yShapePtr->GetStorageShape();
|
||||
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
|
||||
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
|
||||
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
|
||||
if (outPtr != nullptr) {
|
||||
outShape_ = &outPtr->GetStorageShape();
|
||||
}
|
||||
|
||||
auto x = context_->GetInputDesc(X_INPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
|
||||
auto xDtype = x->GetDataType();
|
||||
OP_CHECK_IF(
|
||||
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
|
||||
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
|
||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
if (biasShapePtr != nullptr) {
|
||||
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
|
||||
OP_CHECK_IF((biasDtype != xDtype),
|
||||
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
|
||||
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
|
||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
|
||||
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
|
||||
auto yDtype = yDesc->GetDataType();
|
||||
OP_CHECK_IF((yDtype != xDtype),
|
||||
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
|
||||
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
|
||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
|
||||
auto expertIdDtype = expertIdDesc->GetDataType();
|
||||
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
|
||||
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
|
||||
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
|
||||
auto attrs = context_->GetAttrs();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
|
||||
|
||||
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
|
||||
k_ = *kPtr;
|
||||
moeGatingTopKTilingData_.set_k(k_);
|
||||
OP_LOGI(context_, "Attr k is: %ld ", k_);
|
||||
|
||||
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
|
||||
if (kGroupPtr != nullptr) {
|
||||
kGroup_ = *kGroupPtr;
|
||||
}
|
||||
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
|
||||
|
||||
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
|
||||
if (groupCountPtr != nullptr) {
|
||||
groupCount_ = *groupCountPtr;
|
||||
}
|
||||
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
|
||||
|
||||
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
|
||||
if (groupSelectModePtr != nullptr) {
|
||||
groupSelectMode_ = *groupSelectModePtr;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
|
||||
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
|
||||
|
||||
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
|
||||
if (renormPtr != nullptr) {
|
||||
renorm_ = *renormPtr;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_renorm(renorm_);
|
||||
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
|
||||
|
||||
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
|
||||
if (normTypePtr != nullptr) {
|
||||
normType_ = *normTypePtr;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_normType(normType_);
|
||||
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
|
||||
|
||||
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
|
||||
if (outFlagPtr != nullptr) {
|
||||
outFlag_ = (*outFlagPtr) ? 1 : 0;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_outFlag(outFlag_);
|
||||
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
|
||||
|
||||
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
|
||||
if (routedScalingFactorPtr != nullptr) {
|
||||
routedScalingFactor_ = *routedScalingFactorPtr;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
|
||||
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
|
||||
|
||||
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
|
||||
if (epsPtr != nullptr) {
|
||||
eps_ = *epsPtr;
|
||||
}
|
||||
moeGatingTopKTilingData_.set_eps(eps_);
|
||||
OP_LOGI(context_, "Attr eps is: %f ", eps_);
|
||||
|
||||
auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
|
||||
if (outFlag_ && outDesc != nullptr) {
|
||||
auto outDtype = outDesc->GetDataType();
|
||||
OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT),
|
||||
OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.",
|
||||
ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
|
||||
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// Queries hardware parameters from the platform info attached to the tiling
// context: AIV core count, SoC version and unified-buffer (UB) capacity.
ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo()
{
    auto *rawPlatformInfo = context_->GetPlatformInfo();
    OP_CHECK_IF(rawPlatformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);

    auto ascendPlatform = platform_ascendc::PlatformAscendC(rawPlatformInfo);
    socVersion = ascendPlatform.GetSocVersion();
    aicoreParams_.blockDim = ascendPlatform.GetCoreNumAiv();

    uint64_t ubBytes;
    ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    aicoreParams_.ubSize = ubBytes;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Validates output shapes against the input x and attribute k:
//   y / expertId: [rows, k]; optional norm out: same shape as x.
// Returns GRAPH_FAILED (with a log) on the first mismatch.
ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape()
{
    // Rank checks: every output must have the same rank as x.
    OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
            xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
        OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
            expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
        return ge::GRAPH_FAILED);
    if (outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
            OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
                outShape_->GetDimNum(), xShape_->GetDimNum()),
            return ge::GRAPH_FAILED);
    }

    // dim[0] (row count) checks.
    OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
            xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
        OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
            expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        // Bug fix: the second logged value was outShape_->GetDim(0) again,
        // duplicating the first argument instead of showing the x dim[0]
        // that actually failed the comparison.
        OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
            OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
                outShape_->GetDim(0), xShape_->GetDim(0)),
            return ge::GRAPH_FAILED);
    }

    // dim[1] checks: y/expertId carry k entries per row; norm out carries all experts.
    OP_CHECK_IF((yShape_->GetDim(1) != k_),
        OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
        OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
        return ge::GRAPH_FAILED);
    if (outFlag_ && outShape_ != nullptr) {
        OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
            OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
                xShape_->GetDim(1)),
            return ge::GRAPH_FAILED);
    }
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() {
|
||||
std::vector<int64_t> shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()};
|
||||
ge::Shape softmaxShape(shape_vec);
|
||||
|
||||
uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true);
|
||||
AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData);
|
||||
}
|
||||
|
||||
void MoeGatingTopKTilingRegbase::SplitRows()
|
||||
{
|
||||
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
|
||||
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
|
||||
if (perCoreRows == 0) {
|
||||
OP_LOGE(context_, "perCoreRows can't be 0.");
|
||||
return;
|
||||
}
|
||||
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
|
||||
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
|
||||
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
|
||||
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
|
||||
|
||||
int64_t vmsCount = 0;
|
||||
if (kGroup_ > MRGSORT_SIZE) {
|
||||
int64_t index = MRGSORT_SIZE;
|
||||
while (index < kGroup_) {
|
||||
index = index * MRGSORT_SIZE;
|
||||
vmsCount++;
|
||||
}
|
||||
}
|
||||
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
|
||||
}
|
||||
|
||||
// Runs the validation pipeline (input shapes, attributes, output shapes) and,
// when all checks pass, computes the softmax scratch tiling and the per-core
// row split.
ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling()
{
    ge::graphStatus status = CheckInputShape();
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckAttr();
    }
    if (status == ge::GRAPH_SUCCESS) {
        status = CheckOutShape();
    }
    if (status != ge::GRAPH_SUCCESS) {
        return status;
    }

    CalTmpBufUbSize();
    SplitRows();
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// No separate library-API tiling step is needed here; the softmax tiling is
// already produced in CalTmpBufUbSize() during DoOpTiling().
ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling()
{
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Reports the workspace requirement. This operator only needs the fixed
// default workspace size (no data-dependent scratch in global memory).
ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize()
{
    workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Publishes the tiling results back to the context: block dim, workspace size,
// and the serialized tiling struct.
ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling()
{
    context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
    // Robustness: both context accessors may return null; the original code
    // dereferenced them unchecked.
    size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, currentWorkspace);
    currentWorkspace[0] = workspaceSize_;

    auto *rawTilingData = context_->GetRawTilingData();
    OP_CHECK_NULL_WITH_CONTEXT(context_, rawTilingData);
    moeGatingTopKTilingData_.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Returns the tiling key selecting the regbase kernel implementation.
uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const
{
    return MOE_GATING_TOP_K_REGBASE_TILING_KEY;
}
|
||||
|
||||
void MoeGatingTopKTilingRegbase::Reset()
|
||||
{
|
||||
opName_ = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000);
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_gating_top_k_tiling_base.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "moe_gating_top_k_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../tiling_base/tiling_base.h"
|
||||
#include "../tiling_base/tiling_templates_registry.h"
|
||||
#include "error_log.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
namespace optiling {
|
||||
// Runtime tiling entry point: delegates to the registered tiling template(s)
// for MoeGatingTopK (see REGISTER_OPS_TILING_TEMPLATE).
static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context)
{
    return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
}
|
||||
|
||||
// Compile-time tiling preparation hook: nothing to parse for this operator.
static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context)
{
    (void)context;  // unused; signature required by the registration macro
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Registers the tiling and tiling-parse entry points for MoeGatingTopK.
IMPL_OP_OPTILING(MoeGatingTopK)
    .Tiling(TilingForMoeGatingTopK)
    .TilingParse<MoeGatingTopKCompileInfo>(TilingPrepareForMoeGatingTopK);
|
||||
|
||||
} // namespace optiling
|
||||
89
csrc/moe_gating_top_k/op_kernel/common.h
Normal file
89
csrc/moe_gating_top_k/op_kernel/common.h
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file common.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_GATING_TOP_K_COMMON_H
|
||||
#define MOE_GATING_TOP_K_COMMON_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeGatingTopK {
|
||||
using namespace AscendC;
// fp32 negative infinity, used to pad unused sort slots so they lose every
// comparison in Sort32.
// NOTE(review): type-punning through a pointer cast; assumes F32_NEG_INF holds
// the -inf bit pattern — confirm against AscendC's definition of F32_NEG_INF.
const float MIN_FP32 = *(float *)(&F32_NEG_INF);
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // fp32 -inf bit pattern (-2139095040 as int32)
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;     // elements handled per Sort32 repeat
constexpr int64_t BLOCK_BYTES = 32;             // alignment unit used by Align/AlignBytes
constexpr int64_t REPEAT_BYTES = 256;
constexpr int64_t REPEAT_BLOCKS = 8;            // REPEAT_BYTES / BLOCK_BYTES

// Named small constants to avoid magic numbers at call sites.
constexpr int32_t CONSTANT_TWO = 2;
constexpr int32_t CONSTANT_THREE = 3;
constexpr int32_t CONSTANT_FOUR = 4;
constexpr int32_t CONSTANT_EIGHT = 8;

// Merge-list counts used with MrgSort.
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;

constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;

// Values of the `norm_type` operator attribute.
constexpr int64_t NORM_TYPE_SOFTMAX = 0;
constexpr int64_t NORM_TYPE_SIGMOID = 1;
|
||||
|
||||
// Integer ceiling of a / b; defined as 0 when b == 0.
__aicore__ inline int64_t Ceil(int64_t a, int64_t b)
{
    return (b == 0) ? 0 : (a + b - 1) / b;
}
|
||||
|
||||
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes)
|
||||
{
|
||||
if (bytes == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes;
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes)
|
||||
{
|
||||
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
|
||||
}
|
||||
|
||||
// Smaller of two values; returns a on ties.
template <typename T>
__aicore__ inline T Min(T a, T b)
{
    return (b < a) ? b : a;
}
|
||||
|
||||
// Larger of two values; returns a on ties.
template <typename T>
__aicore__ inline T Max(T a, T b)
{
    if (a < b) {
        return b;
    }
    return a;
}
|
||||
|
||||
// Signed ceiling division: rounds x / y toward positive infinity.
// When x and y share a sign and the division has a remainder, the truncated
// quotient is bumped by one; otherwise truncation already equals the ceiling.
// Degenerate cases: returns x unchanged when y == 0 or x == 0.
template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 x, T2 y)
{
    if (y != 0 && x != 0) {
        const T1 quotient = x / y;
        // (x ^ y) >= 0  <=>  x and y have the same sign (integral types)
        return (x % y != 0 && ((x ^ y) >= 0)) ? (quotient + 1) : quotient;
    }

    return x;
}
|
||||
|
||||
} // namespace MoeGatingTopK
|
||||
#endif // MOE_GATING_TOP_K_COMMON_H
|
||||
55
csrc/moe_gating_top_k/op_kernel/error_log.h
Normal file
55
csrc/moe_gating_top_k/op_kernel/error_log.h
Normal file
@@ -0,0 +1,55 @@
|
||||
// Lightweight logging / validation macros for the moe_gating_top_k tiling code.
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

#include <cstdio>
#include <string>
#include "toolchain/slog.h"

// Info/debug logs are compiled out.
#define OP_LOGI(opname, ...)
#define OP_LOGD(opname, ...)

// Bug fix for the printf-based macros below: the original form
//   printf("[WARN][%s] ", (opname), ##__VA_ARGS__);
// consumed only `opname` with its single %s — the actual message (format
// string and arguments) was passed as excess printf arguments, which are
// evaluated and DISCARDED, so no message text was ever printed. In addition,
// most call sites pass a TilingContext* as `opname`, which %s would print as
// undefined behavior. `opname` is therefore no longer formatted; it is still
// evaluated once so call-site expressions keep their effects.
#define OP_LOGW(opname, ...) \
    do { \
        (void)(opname); \
        printf("[WARN] "); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do { \
        (void)(opname); \
        printf("[ERRORx] "); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

#define OP_LOGE(opname, ...) \
    do { \
        (void)(opname); \
        printf("[ERROR] "); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)

namespace optiling {

#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)

// If `cond` holds, run the logging expression, then `expr` (usually `return ...`).
#define OP_CHECK_IF(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)

// Fail the enclosing tiling function with GRAPH_FAILED when `ptr` is null.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
    do { \
        if ((ptr) == nullptr) { \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED; \
        } \
    } while (0)

} // namespace optiling

#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
63
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp
Normal file
63
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "moe_gating_top_k_e_k_fullload.h"
|
||||
#include "moe_gating_top_k_without_group.h"
|
||||
#include "moe_gating_top_k_generalized.h"
|
||||
#include "error_log.h"
|
||||
|
||||
// Tiling keys produced by the host tiling; one per kernel template below.
#define TILING_KEY_PER_GROUP_COUNT_32 0
#define TILING_KEY_WITHOUT_GROUP 1
#define TILING_KEY_GENERALIZED 2

using namespace AscendC;
using namespace MoeGatingTopK;
// Kernel entry: picks the implementation selected by the tiling key and runs
// it on vector (AIV) cores only; cube (AIC) cores exit immediately.
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    if (g_coreType == AIC) {
        return;
    }

    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling);
    if (workspace == nullptr) {
        return;
    }

    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }

    const MoeGatingTopKTilingData *__restrict t = &tilingData;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) {
        // "E/K full load" variant (per class name) — presumably the whole
        // expert row fits in UB; confirm against the host tiling selection.
        MoeGatingTopKEKFullload<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) {
        // Variant without expert grouping.
        MoeGatingTopKWithoutGroup<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    } else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) {
        // Generalized fallback variant.
        MoeGatingTopKGenerlized<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
        op.Process();
    }
}
|
||||
46
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp
Normal file
46
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp
Normal file
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_apt.cpp
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "arch35/moe_gating_top_k_regbase.h"
|
||||
using namespace AscendC;
|
||||
using namespace MoeGatingTopK;
|
||||
|
||||
// Tiling key for the regbase (arch35) implementation.
#define TILING_KEY_REGBASE 10000

// Regbase kernel entry for arch35 SoCs; vector cores only.
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                       GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    if (g_coreType == AIC) {
        return;
    }

    if (workspace == nullptr) {
        return;
    }

    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }

    GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling);
    const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in;
    TPipe tPipe;
    if (TILING_KEY_IS(TILING_KEY_REGBASE)) {
        MoeGatingTopKRegbase<DTYPE_X> op;
        op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe);
        op.Process();
    }
}
|
||||
404
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h
Normal file
404
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h
Normal file
@@ -0,0 +1,404 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_e_k_fullload.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
||||
#define MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
||||
#include "kernel_operator.h"
|
||||
#include "common.h"
|
||||
namespace MoeGatingTopK {
|
||||
using namespace AscendC;
|
||||
|
||||
// Row-wise MoE gating top-k kernel for the "E/K full load" case: each row of
// gating logits is processed in UB in one pass (copy in -> sigmoid(+bias) ->
// per-group sort -> group/expert selection -> copy out).
template <typename T>
class MoeGatingTopKEKFullload {
public:
    __aicore__ inline MoeGatingTopKEKFullload(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();

private:
    // Pipeline stages executed per row.
    __aicore__ inline void CopyInBias();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    __aicore__ inline void SortInGroup();
    __aicore__ inline void SelectTopKGroupIndex();
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CopyOut(int64_t progress);

private:
    TPipe *pipe_;  // queue/buffer manager supplied by the kernel entry (not owned)
    TQue<QuePosition::VECIN, 1> xInQueue_;    // current x row (fp32)
    TBuf<TPosition::VECCALC> biasInQueue_;    // bias, loaded once by CopyInBias()
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;

    TQue<QuePosition::VECOUT, 1> xBiasQueue_;     // sigmoid(x) + bias (selection scores)
    TQue<QuePosition::VECOUT, 1> xSigmoidQueue_;  // sigmoid(x)
    TQue<QuePosition::VECIN, 1> sigmoidTmpQueue_; // reusable scratch buffer
    TQue<QuePosition::VECIN, 1> sortedInGroupQueue_;  // per-group Sort32 results
    TQue<QuePosition::VECIN, 1> sortedGroupQueue_;
    TBuf<TPosition::VECCALC> calcTmpBuffer_;  // index scratch for sorting/gathering

    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<T> outGm_;

    int64_t blockIdx_;
    int64_t perCoreRowCount_;
    int64_t curCoreRowCount_;
    int64_t expertCount_;
    bool addBias_;  // whether the optional bias input is present
    // Operator attributes, mirrored from tiling data.
    int64_t k_;
    int64_t kGroup_;
    int64_t groupCount_;
    int64_t groupSelectMode_;
    int64_t renorm_;
    int64_t normType_;
    int64_t outFlag_;
    float routedScalingFactor_;
    float eps_;

    int64_t expertCountAlign_;  // expertCount_ padded for aligned UB access
    int64_t kAlign_;
    int64_t perGroupExpertCount_;

    const MoeGatingTopKTilingData *tilingData_;
};
|
||||
|
||||
// Loads the bias vector (expertCount_ elements) from GM into the persistent
// biasInQueue_ buffer as fp32. For half/bf16 inputs the raw data is staged at
// offset expertCountAlign_ and then widened to fp32 at the front of the buffer.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInBias()
{
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
        // MTE2 -> V handshake so vector ops see the freshly copied data.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
    } else {
        DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        // Widen T -> fp32 into the front of the same buffer.
        Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE, expertCount_);
    }
}
|
||||
|
||||
// Loads one row of x (expertCount_ gating logits) from GM into UB as fp32 and
// enqueues it for ComputeX(). Uses the same half/bf16 staging-then-cast scheme
// as the bias load.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
        // MTE2 -> V handshake before casting the staged raw data.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
    }

    xInQueue_.EnQue(xInLocalTensor);
}
|
||||
|
||||
// Applies sigmoid to the current row and adds the bias (when present):
//   xSigmoidQueue_ <- sigmoid(x)
//   xBiasQueue_    <- sigmoid(x) + bias   (the scores used for selection)
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::ComputeX()
{
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.AllocTensor<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xBiasTensor = xBiasQueue_.AllocTensor<float>();
    LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
    LocalTensor<uint8_t> sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor<uint8_t>(); // scratch space, reusable
    Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_);
    PipeBarrier<PIPE_V>();
    if (addBias_) {
        Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_);
    } else {
        // No bias input: pass sigmoid(x) through unchanged (add 0).
        Adds(xBiasTensor, xSigmoidTensor, static_cast<float>(0), expertCount_);
    }

    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    xBiasQueue_.EnQue<float>(xBiasTensor);
    xInQueue_.FreeTensor(xInLocalTensor);
    sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer);
}
|
||||
|
||||
// Sorts the biased scores within each 32-element chunk (Sort32 repeat unit),
// pairing every score with its expert index; the sorted lists feed the later
// group selection / merge stages.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SortInGroup()
{
    LocalTensor<float> xBiasTensor = xBiasQueue_.DeQue<float>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.AllocTensor<float>(); // per-group sort result, needed by the later merge
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>(); // indices carried alongside the sort keys
    ArithProgression(indexTensor.ReinterpretCast<int32_t>(), 0, 1, expertCount_); // expert indices 0, 1, 2, ...
    PipeBarrier<PIPE_V>();
    Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // sort within each group
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    xBiasQueue_.FreeTensor(xBiasTensor);
}
|
||||
|
||||
// Ranks expert groups by the sum of their top-2 scores and leaves the indices
// of the kGroup_ best groups, sorted ascending, in calcTmpBuffer_ (as int32)
// for SelectTopKExpertIdx() to consume.
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKGroupIndex()
{
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>();
    LocalTensor<float> top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor<float>(); // reusable scratch buffer
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);

    // Gather pattern 0b0101: picks the first two (score, index) pairs — i.e.
    // the top-2 scores — out of each sorted 32-element group.
    indexTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
    indexTensor.SetValue(1, static_cast<uint32_t>(0));
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);

    uint64_t rsvdCnt = 0; // number of elements retained by the gather
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = 8;
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = 8;
    gatherMaskParams.src1RepeatStride = 0;
    GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast<uint32_t>(64),
               gatherMaskParams, rsvdCnt);
    PipeBarrier<PIPE_V>();
    LocalTensor<float> groupTop2SumTensor = top2ValueInGroupTensor;
    PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1,
                  1); // sum of the two largest scores per group
    PipeBarrier<PIPE_V>();

    LocalTensor<uint32_t> groupIndexTensor = indexTensor;
    ArithProgression(groupIndexTensor.ReinterpretCast<int32_t>(), 0, 1, groupCount_); // group indices 0, 1, 2, ...
    PipeBarrier<PIPE_V>();
    // Pad up to 32 entries with -inf so the padding loses every comparison.
    int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << groupCount_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    // Sort the group sums (descending) to expose the kGroup_ best groups.
    LocalTensor<float> sortedGroupTensor = sortedGroupQueue_.AllocTensor<float>();
    Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1);

    PipeBarrier<PIPE_V>();
    LocalTensor<int32_t> sortedGroupIndexTensor = indexTensor.ReinterpretCast<int32_t>();
    // Extract the group indices from the sorted (score, index) pairs.
    uint8_t src1Pattern = 2; // built-in fixed gather pattern
    GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);

    // Re-sort the selected group indices themselves (this sort is descending,
    // so MrgSort later consumes the lists in reverse order: 3, 2, 1, 0).
    Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_);
    PipeBarrier<PIPE_V>();
    duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX << kGroup_;
        uint64_t mask[2] = {mask0, 0};
        Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8);
        PipeBarrier<PIPE_V>();
    }
    Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast<uint32_t>(), 1);
    PipeBarrier<PIPE_V>();
    src1Pattern = 1;
    GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast<uint32_t>(0), {1, 1, 0, 0},
               rsvdCnt);
    PipeBarrier<PIPE_V>();
    Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_);

    sortedGroupQueue_.FreeTensor(sortedGroupTensor);
    sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
    sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertIdx()
{
    // Merge the per-group sorted (score, index) lists of the selected top
    // groups into one globally sorted list, then extract the expert indices.
    // Assumes kGroup_ <= 4 here, since exactly four merge lists are consumed
    // (the generalized kernel handles larger kGroup_) — TODO confirm.
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> topKGroupIndexTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
    LocalTensor<float> sortedExpertTensor = xInQueue_.AllocTensor<float>();
    AscendC::MrgSort4Info params;
    params.elementLengths[0] = k_;
    params.elementLengths[1] = k_;
    params.elementLengths[2] = k_;
    params.elementLengths[3] = k_;
    // Stop as soon as one list is exhausted; only the global top-k is needed.
    params.ifExhaustedSuspension = true;
    params.validBit = 0b1111;
    params.repeatTimes = 1;
    // V->S handshake: GetValue below reads UB from scalar pipe, so the vector
    // pipe must have finished producing topKGroupIndexTensor first.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    // Group indices were sorted ascending by the previous stage, so they are
    // consumed in reverse (3, 2, 1, 0) to feed the merge in descending order.
    // The *2 factor accounts for the interleaved (value, index) sort layout.
    int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2;
    int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2;
    int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2;
    int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2;
    // S->V handshake: the scalar reads above must complete before the vector
    // pipe consumes the tensors again.
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    AscendC::MrgSortSrcList<float> srcList;
    srcList.src1 = sortedInGroupTensor[listOffset1];
    srcList.src2 = sortedInGroupTensor[listOffset2];
    srcList.src3 = sortedInGroupTensor[listOffset3];
    srcList.src4 = sortedInGroupTensor[listOffset4];
    MrgSort<float>(sortedExpertTensor, srcList, params);
    PipeBarrier<PIPE_V>();
    uint64_t rsvdCnt = 0;    // number of elements retained after filtering
    uint8_t src1Pattern = 2; // built-in fixed pattern: gather odd lanes (the index halves)
    GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
    xInQueue_.FreeTensor(sortedExpertTensor);
    expertIdxOutQueue_.EnQue(expertIdxTensor);
    sortedInGroupQueue_.FreeTensor(sortedInGroupTensor);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertScore()
{
    // Gather the activation values of the selected top-k experts, renormalize
    // them (divide by their sum + eps), and scale by routedScalingFactor_.
    LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
    LocalTensor<int32_t> expertByteIdxTensor = calcTmpBuffer_.Get<int32_t>();
    LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
    LocalTensor<T> yTensor = yOutQueue_.AllocTensor<T>();
    LocalTensor<float> yOutTensor;
    if constexpr (!IsSameType<T, float>::value) {
        // For half-precision output, compute in fp32 in the upper part of the
        // same buffer and cast down into yTensor at the end.
        yOutTensor = yTensor.template ReinterpretCast<float>()[kAlign_];
    } else {
        yOutTensor = yTensor;
    }
    // Gather addresses are byte offsets, so scale element indices by sizeof(float).
    Muls(expertByteIdxTensor, expertIdxTensor, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast<uint32_t>(),
           static_cast<uint32_t>(0), k_);

    LocalTensor<float> calTensor = calcTmpBuffer_.Get<float>();
    PipeBarrier<PIPE_V>();
    // NOTE(review): xSigmoidTensor is passed as the ReduceSum work buffer here,
    // which may clobber the activation values; it is only freed afterwards in
    // CopyOut, but confirm nothing still reads it.
    ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_);
    // V->S: the scalar read of the reduction result must wait for the vector pipe.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);
    float sumValue = calTensor.GetValue(0) + eps_; // eps_ guards against division by zero
    // S->V: vector ops below must see the scalar read completed.
    event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
    SetFlag<HardEvent::S_V>(eventIdSToV);
    WaitFlag<HardEvent::S_V>(eventIdSToV);
    Duplicate(calTensor, sumValue, k_);
    PipeBarrier<PIPE_V>();
    Div(yOutTensor, yOutTensor, calTensor, k_);
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_);

    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        // Round-to-nearest-even cast back to the narrow output type.
        Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_);
    }

    xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxTensor);
    yOutQueue_.EnQue(yTensor);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyOut(int64_t row)
|
||||
{
|
||||
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
|
||||
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
|
||||
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
|
||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
|
||||
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
|
||||
dataCopyParams.blockLen = k_ * sizeof(int32_t);
|
||||
DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams);
|
||||
xSigmoidQueue_.FreeTensor(xSigmoidTensor);
|
||||
expertIdxOutQueue_.FreeTensor(expertIdxTensor);
|
||||
yOutQueue_.FreeTensor(yOutTensor);
|
||||
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKEKFullload<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    // One-time per-core setup: cache tiling parameters, bind global-memory
    // buffers at this core's row offset, and allocate unified-buffer queues.
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core picks up the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    routedScalingFactor_ = tilingData_->routedScalingFactor;
    eps_ = tilingData_->eps;

    expertCountAlign_ = Align(expertCount_, sizeof(float));
    // NOTE(review): kAlign_ is computed from expertCount_, not k_, despite its
    // name — looks like a copy-paste of the line above. Usage below is
    // self-consistent (buffer size and in-buffer offset both use kAlign_), so
    // this only over-allocates; confirm whether Align(k_, sizeof(float)) was
    // intended.
    kAlign_ = Align(expertCount_, sizeof(float));

    // init input gm buf — each core starts at its own row offset.
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);

    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);

    // init que — the (sizeof(float) / sizeof(T)) factor reserves extra room so
    // narrow (half) inputs can be staged and widened to fp32 in place.
    pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    // NOTE(review): no queue-depth argument here, unlike the other InitBuffer
    // calls — confirm biasInQueue_'s type expects this overload.
    pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));

    pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float)));
    pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float)));

    pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t)));
    pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float)));

    pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
    // *2: sorted buffers hold interleaved (value, index) pairs.
    pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2);
    pipe_->InitBuffer(sortedGroupQueue_, 2,
                      (groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM *
                          sizeof(float) * 2);

    // Scratch sized by host-side tiling.
    pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::Process()
|
||||
{
|
||||
CopyInBias();
|
||||
for (int64_t row = 0; row < curCoreRowCount_; row++) {
|
||||
CopyInX(row);
|
||||
ComputeX();
|
||||
SortInGroup();
|
||||
SelectTopKGroupIndex();
|
||||
SelectTopKExpertIdx();
|
||||
SelectTopKExpertScore();
|
||||
CopyOut(row);
|
||||
}
|
||||
}
|
||||
} // namespace MoeGatingTopK
|
||||
#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
||||
669
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h
Normal file
669
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h
Normal file
@@ -0,0 +1,669 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_generalized.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
||||
#define MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
||||
#include "kernel_operator.h"
|
||||
#include "common.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace MoeGatingTopK {
|
||||
using namespace AscendC;
|
||||
|
||||
// Generalized MoE gating top-k kernel: unlike the full-load variant, this
// class handles arbitrary group counts / group sizes via iterative merge
// sorting, and supports both sigmoid and softmax normalization with optional
// renormalization.
// NOTE(review): class name "Generlized" and method names "CopuOutXNorm" /
// "CumputeActualTopKExpertId" are misspelled; renaming would touch every
// call site, so flagged here only.
template <typename T>
class MoeGatingTopKGenerlized {
public:
    __aicore__ inline MoeGatingTopKGenerlized(){};
    // Bind global-memory pointers and tiling, allocate UB buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    // Run the full per-row gating pipeline for this core's rows.
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SortInGroup();
    __aicore__ inline void SelectTopKGroupIndex();
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CumputeActualTopKExpertId();
    __aicore__ inline void CopyOut(int64_t row);

private:
    TPipe *pipe_;
    TQue<QuePosition::VECIN, 1> xInQueue_;
    TQue<QuePosition::VECOUT, 1> yOutQueue_;
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
    TQue<QuePosition::VECOUT, 1> outOutQueue_;

    TBuf<TPosition::VECCALC> biasBuf_;           // Store input bias
    TBuf<TPosition::VECCALC> expertIdBuf_;       // Expert ID (arithmetic progression 0..expertCount)
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_;  // Store value after adding bias
    TBuf<TPosition::VECCALC> xNormBuf_;          // Store value after computing sigmoid or softmax
    TBuf<TPosition::VECCALC> sortedInGroupBuf_;  // Store sorted results within groups
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;   // Final top-k expert ids
    TBuf<TPosition::VECCALC> sortedGroupIndexBuf_; // Top-kGroup group indices, sorted
    TBuf<TPosition::VECCALC> calcTmpBuf_;        // Shared scratch, sized by tiling

    GlobalTensor<T> xGm_;
    GlobalTensor<T> biasGm_;
    GlobalTensor<T> yGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> outGm_;

    int64_t blockIdx_ = 0;
    int64_t perCoreRowCount_ = 0;        // rows per core from tiling
    int64_t curCoreRowCount_ = 0;        // rows this core actually processes
    int64_t expertCount_ = 0;            // total number of experts (E)
    bool addBias_ = false;               // whether gating adds an expert bias
    int64_t k_ = 0;                      // experts selected per row
    int64_t kGroup_ = 0;                 // groups selected per row
    int64_t groupCount_ = 0;             // number of expert groups
    int64_t groupCountAlign_ = 0;
    int64_t perGroupExpertCount_ = 0;    // experts per group (E / groupCount)
    int64_t perGroupExpertCountAlign_ = 0;
    int64_t groupSelectMode_ = 0;        // 1: rank groups by top-2 sum; else by top-1
    int64_t renorm_ = 0;                 // renormalize selected scores
    int64_t normType_ = 0;               // 1: sigmoid, 0: softmax
    int64_t outFlag_ = 0;                // whether to emit the normalized full row

    int64_t expertCountAlign_ = 0;
    int64_t kAlign_ = 0;
    bool isAlign_ = false;               // perGroupExpertCount_ already repeat-aligned

    const MoeGatingTopKTilingData *tilingData_;
};
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInBiasAndInitExpertId()
{
    // Stage the expert bias in UB (widened to fp32 for narrow T), pad each
    // group's tail lanes with -inf so padding never wins a sort, and fill the
    // expert-id tensor with 0..expertCountAlign_-1.
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    // Destination stride re-lays each group out at its aligned width in UB.
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;

    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            // MTE2->V: vector ops must wait for the GM->UB copy to land.
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // Narrow T: land raw data in the upper half of the buffer, then
            // widen it into the fp32 lower half.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }

        if (!isAlign_) {
            // Pad the tail of each group's last sort repeat with -inf.
            int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
            int duplicateIndex = perGroupExpertCount_ - duplicateNum;
            if (duplicateNum > 0) {
                // Mask selects lanes [duplicateNum, ONE_REPEAT_SORT_NUM) of the repeat.
                uint64_t mask0 = UINT64_MAX;
                mask0 = mask0 << duplicateNum;
                mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
                uint64_t mask[2] = {mask0, 0};
                // FLOAT32_NEG_INF is written through an int32 view (bit pattern).
                Duplicate(biasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                          perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
            }
        }
    }
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCountAlign_);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInX(int64_t row)
{
    // Copy one row of gating activations from GM into UB, laying each group
    // out at its aligned width (dstStride inserts the per-group padding).
    // For narrow T the raw data lands in the upper half of the fp32 buffer;
    // ComputeX() widens it afterwards.
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = groupCount_;
    dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
    dataCopyParams.srcStride = 0;
    dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;

    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::ComputeX()
{
    // Normalize the staged row (sigmoid when normType_==1, softmax when
    // normType_==0) into xNormBuf_, then add the optional bias into
    // xNormWithBiasBuf_. Per-group padding lanes are forced to -inf both
    // before normalization (so softmax ignores them) and after biasing (so
    // sorting ignores them).
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();

    if constexpr (!IsSameType<T, float>::value) {
        // Widen the narrow input (staged in the upper half) to fp32 in place.
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }

    int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = perGroupExpertCount_ - duplicateNum;
    if (!isAlign_ && duplicateNum > 0) {
        // Mask the padding lanes of each group's last sort repeat with -inf.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xInLocalTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
                  (perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES);
        PipeBarrier<PIPE_V>();
    }
    if (normType_ == 1) { // sigmoid
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    else if (normType_ == 0) { // softmax: exp(x - max) / sum, max-subtracted for stability
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_);
        // V->S handshake before the scalar read of the row max.
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        // S->V handshake before the vector pipe resumes.
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCountAlign_);
        PipeBarrier<PIPE_V>();
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_);
        PipeBarrier<PIPE_V>();
    }
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }

    if (!isAlign_ && duplicateNum > 0) {
        // Re-apply the -inf padding after the bias add, since the bias buffer's
        // padding may have perturbed those lanes.
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex],
                  FLOAT32_NEG_INF, // MIN_FP32,
                  mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopuOutXNorm(int64_t row)
{
    // Write the full normalized row (pre-top-k) back to the `out` global
    // buffer, stripping the per-group alignment padding on the way out.
    // NOTE(review): name has a typo ("Copu" for "Copy"); it matches the class
    // declaration, so a rename must update all call sites together.
    LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
    // EnQue/DeQue round-trip inserts the V -> MTE3 synchronization.
    outOutQueue_.EnQue<float>(outOutTensor);
    outOutTensor = outOutQueue_.DeQue<float>();
    // srcStride skips the per-group padding lanes in UB; GM rows are dense.
    DataCopyExtParams dataCopyParams{
        static_cast<uint16_t>(groupCount_), static_cast<uint32_t>(perGroupExpertCount_ * sizeof(float)),
        static_cast<uint32_t>((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0};
    DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
    outOutQueue_.FreeTensor(outOutTensor);
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SortInGroup()
{
    // Sort each group's biased scores descending, producing interleaved
    // (score, expert-id) pairs in sortedInGroupBuf_. Fast path: when a group
    // is exactly one sort repeat wide, a single batched Sort32 covers all
    // groups; otherwise each group is sorted with the multi-repeat Sort.
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    LocalTensor<float> tmpLocal = calcTmpBuf_.Get<float>();
    if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) {
        PipeBarrier<PIPE_V>();
        Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_);
    } else {
        for (int64_t group = 0; group < groupCount_; group++) {
            PipeBarrier<PIPE_V>();
            // Destination offset doubles because output pairs are (value, index).
            Sort<float, true>(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO],
                              xNormWithBiasTensor[group * perGroupExpertCountAlign_],
                              expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal,
                              perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM);
        }
    }
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKGroupIndex()
{
    // Score each group (top-2 sum or top-1, per groupSelectMode_), sort the
    // group scores descending, and keep the indices of the best kGroup_
    // groups, themselves sorted, in sortedGroupIndexBuf_.
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    // Scratch layout inside calcTmpBuf_, in units of groupCountAlign_ floats:
    // [0,2) gathered values, [2,3) gather mask, [3,4) per-group score,
    // [4,5) group indices, [5,7) sorted (score,idx) pairs, [7,9) sort scratch.
    LocalTensor<float> valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, 0);
    LocalTensor<uint32_t> maskTensor =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float));
    LocalTensor<float> topValueInGroupTensor =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float));
    LocalTensor<uint32_t> groupIndex =
        calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float));
    LocalTensor<float> sortedTopValue =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float));
    LocalTensor<float> sortTmp =
        calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float));

    // V->S handshake before the scalar SetValue calls on maskTensor below.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);

    uint64_t rsvdCnt = 0; // number of elements retained after filtering
    PipeBarrier<PIPE_V>();
    if (groupSelectMode_ == 1) { // rank groups by the sum of their top-2 scores
        // Gather the first two values (lanes 0 and 2 of each (value,index)
        // pair stream) from each sorted group.
        maskTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
        maskTensor.SetValue(1, static_cast<uint32_t>(0));
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);

        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        gatherMaskParams.src0RepeatStride =
            Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
        PipeBarrier<PIPE_V>();

        // Pairwise-sum adjacent gathered values: each group's two largest
        // scores collapse into one group score.
        PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor,
                      Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1,
                      CONSTANT_EIGHT);
    } else { // rank groups by their single largest score
        maskTensor.SetValue(0, static_cast<uint32_t>(1)); // b0001: only lane 0 (top-1 value) per group
        maskTensor.SetValue(1, static_cast<uint32_t>(0));

        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // NOTE(review): this declaration shadows the outer rsvdCnt; harmless
        // since the outer value is only consumed by the later GatherMask.
        uint64_t rsvdCnt = 0; // number of elements retained after filtering
        GatherMaskParams gatherMaskParams;
        gatherMaskParams.repeatTimes = groupCount_;
        gatherMaskParams.src0BlockStride = 1;
        gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32);
        gatherMaskParams.src1RepeatStride = 0;
        GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true,
                   static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
    }

    PipeBarrier<PIPE_V>();
    // Generate group indices 0..groupCount_-1 to ride along in the sort.
    ArithProgression(groupIndex.ReinterpretCast<int32_t>(), static_cast<int32_t>(0), static_cast<int32_t>(1),
                     groupCount_);
    PipeBarrier<PIPE_V>();

    // Pad the tail of the last sort repeat with -inf so padding never ranks.
    int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = groupCount_ - duplicateNum;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(topValueInGroupTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1,
                  REPEAT_BLOCKS);
        PipeBarrier<PIPE_V>();
    }
    PipeBarrier<PIPE_V>();

    // Sort group scores descending (carrying group indices).
    Sort<float, true>(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32));
    PipeBarrier<PIPE_V>();

    // Extract the group indices (odd lanes of the sorted pairs).
    uint8_t src1Pattern = 2; // built-in fixed pattern
    GatherMask(groupIndex, sortedTopValue.template ReinterpretCast<uint32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0),
               {1, static_cast<uint8_t>(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt);
    PipeBarrier<PIPE_V>();
    duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        duplicateIndex = kGroup_ - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        PipeBarrier<PIPE_V>();
        Duplicate(groupIndex.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS);
    }

    // Sort the selected group indices (descending; the merge stage consumes
    // them in reverse to restore ascending order).
    LocalTensor<float> sortedGroupIndex = sortedGroupIndexBuf_.Get<float>();
    PipeBarrier<PIPE_V>();
    Sort<float, true>(sortedGroupIndex, groupIndex.ReinterpretCast<float>(), groupIndex, sortTmp, Ceil(kGroup_, 32));
}
|
||||
|
||||
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertIdx()
{
    // Merge the sorted lists of the kGroup_ selected groups into one global
    // descending order, then extract the top-k expert ids.
    // Stage 1: merge groups 4-at-a-time into mrgSort0Tensor.
    // Stage 2: vmsCount rounds of 4-way merges, ping-ponging between
    //          mrgSort0Tensor and sortedInGroupTensor.
    // Stage 3: gather the index halves of the final list into topKExpertId.
    LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
    LocalTensor<int32_t> sortedGroupIndex = sortedGroupIndexBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> mrgSort0Tensor = calcTmpBuf_.Get<float>();

    uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0};
    uint16_t lenArr[CONSTANT_FOUR] = {
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_),
        static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_)};
    MrgSort4Info params{lenArr, false, 0b1111, 1};
    MrgSortSrcList<float> srcList;

    // V->S handshake: sortedGroupIndex.GetValue below is a scalar read of UB.
    event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
    SetFlag<HardEvent::V_S>(eventIdVToS);
    WaitFlag<HardEvent::V_S>(eventIdVToS);

    // Stage 1 — walk the group indices from the back (they were sorted
    // descending) so the merged output comes out in the intended order.
    // The "* 2" factors account for interleaved (value, index) pairs, and
    // "i * 2" steps over the sorted pair layout of sortedGroupIndex.
    for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) {
        int64_t mrgLen = Min(i + 1, CONSTANT_FOUR);
        if (mrgLen > 1) {
            if (mrgLen == MERGE_LIST_FOUR) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2;
            } else if (mrgLen == MERGE_LIST_THREE) {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
                offset[3] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b111;
            } else {
                offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
                offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
                offset[2] = 0;
                offset[3] = 0;
                params.elementLengths[2] = 0;
                params.elementLengths[3] = 0;
                params.validBit = 0b11;
            }

            srcList.src1 = sortedInGroupTensor[offset[0]];
            srcList.src2 = sortedInGroupTensor[offset[1]];
            srcList.src3 = sortedInGroupTensor[offset[2]];
            srcList.src4 = sortedInGroupTensor[offset[3]];

            PipeBarrier<PIPE_V>();
            MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params);
        } else {
            // Single leftover list: plain copy, nothing to merge.
            offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
            PipeBarrier<PIPE_V>();
            DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]],
                     perGroupExpertCountAlign_ * 2);
        }
    }
    // Stage 2 — repeated 4-way merges; each round multiplies the run length
    // (in groups) by MERGE_LIST_FOUR. vmsCount is precomputed by host tiling.
    int32_t baseLoop = 4;
    LocalTensor<float> srcTensor = mrgSort0Tensor;
    LocalTensor<float> dstTensor = mrgSort0Tensor;
    for (int i = 0; i < tilingData_->vmsCount; i++) {
        // Ping-pong the source/destination buffers each round.
        if (i % 2 == 0) {
            srcTensor = mrgSort0Tensor;
            dstTensor = sortedInGroupTensor;
        } else {
            srcTensor = sortedInGroupTensor;
            dstTensor = mrgSort0Tensor;
        }

        int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR;
        int32_t quotient = kGroup_ / nextBaseRow;
        int32_t remainder = kGroup_ - quotient * nextBaseRow;
        if (quotient > 0) {
            // Full 4-way merges of equal-length runs.
            MrgSort4Info params;
            MrgSortSrcList<float> srcList;
            params.ifExhaustedSuspension = false;
            params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
            params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
            params.validBit = 0b1111;
            params.repeatTimes = 1;
            for (int j = 0; j < quotient; j++) {
                // Stride 8 = 4 runs * 2 floats per (value, index) pair.
                srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j];
                srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)];
                srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)];
                srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)];
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params);
            }
        }

        if (remainder > 0) {
            // Tail: fewer than four runs remain; the last run may be short.
            int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2;
            int32_t mrgLen = CeilDiv(remainder, baseLoop);
            int32_t tailRow = remainder - (mrgLen - 1) * baseLoop;
            if (mrgLen > 1) {
                MrgSort4Info params;
                MrgSortSrcList<float> srcList;
                params.repeatTimes = 1;
                params.ifExhaustedSuspension = false;
                params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
                params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
                srcList.src1 = srcTensor[baseOffset];
                srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2];
                if (mrgLen == MERGE_LIST_FOUR) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3];
                    params.elementLengths[3] = perGroupExpertCount_ * tailRow; // short last run
                    params.validBit = 0b1111;
                } else if (mrgLen == MERGE_LIST_THREE) {
                    srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
                    params.elementLengths[2] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b111;
                } else {
                    params.elementLengths[1] = perGroupExpertCount_ * tailRow;
                    params.elementLengths[2] = 0;
                    params.elementLengths[3] = 0;
                    params.validBit = 0b11;
                }
                PipeBarrier<PIPE_V>();
                MrgSort(dstTensor[baseOffset], srcList, params);
            } else {
                // Single leftover run: copy it through to keep the ping-pong consistent.
                PipeBarrier<PIPE_V>();
                DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2);
            }
        }
        baseLoop = nextBaseRow;
    }

    // Stage 3 — gather the index halves of the first k_ pairs of the final
    // merged list (dstTensor holds the last round's output).
    GatherMaskParams gatherMaskParams;
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;

    uint64_t rsvdCnt = 0;    // number of elements retained after filtering
    uint8_t src1Pattern = 2; // built-in fixed pattern: odd lanes (index halves)
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, dstTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
}
|
||||
|
||||
// Gathers the scores of the k selected experts and post-processes them:
// optional renormalisation (x / (sum(x) + eps)) followed by scaling with
// routedScalingFactor.  Consumes topKExpertIdBuf_ (positions in the padded,
// sorted layout) and xNormBuf_ (per-expert normalised scores); the result is
// enqueued on yOutQueue_, cast back to T for 16-bit dtypes.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertScore()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
    PipeBarrier<PIPE_V>();
    // Gather takes byte offsets, so convert element indices to bytes first.
    Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
    PipeBarrier<PIPE_V>();
    Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
           k_);
    // NOTE(review): for normType_ == 1 (sigmoid) the renorm_ flag is ignored
    // and renormalisation always runs — confirm this is intended.
    bool needRenorm = (normType_ == 1 ) || // Case 1: sigmoid + renorm
                      (normType_ == 0 && renorm_ == 1); // Case 3: softmax + renorm
    if (needRenorm) {
        LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
        // Scratch placed 32 floats in, past the reduction result slot.
        // NOTE(review): magic offset — the without-group variant uses
        // BLOCK_BYTES here; confirm they are meant to agree.
        LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[32];
        PipeBarrier<PIPE_V>();
        ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
        // Scalar read of the reduction result needs a V -> S handshake.
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        // eps guards the division when the selected scores sum to ~0.
        float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Duplicate(tmpTensor, sumValue, k_);
        PipeBarrier<PIPE_V>();
        Div(yOutTensor, yOutTensor, tmpTensor, k_);
    }
    PipeBarrier<PIPE_V>();
    Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);

    // Narrow back to the caller's dtype (half/bf16) in place.
    if constexpr (!IsSameType<T, float>::value) {
        PipeBarrier<PIPE_V>();
        Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
    }

    yOutQueue_.EnQue<float>(yOutTensor);
}
|
||||
|
||||
// Maps the selected positions from the padded per-group layout back to real
// expert ids.  Positions were produced against rows of
// perGroupExpertCountAlign_ entries, so each carries a padding offset of
// groupIdx * (align - count):
//   group  = floor(pos / perGroupExpertCountAlign_)
//   realId = pos - group * (perGroupExpertCountAlign_ - perGroupExpertCount_)
// (Function name keeps the original "Cumpute" typo because Process() calls it.)
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CumputeActualTopKExpertId()
{
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> topKExpertIdFp32 = calcTmpBuf_.Get<float>();

    PipeBarrier<PIPE_V>();
    // Integer division emulated in float: cast, multiply by reciprocal, then
    // truncate.  NOTE(review): exact only while positions stay within
    // float32's 2^24 integer range — confirm expert-count bounds ensure this.
    Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_);
    PipeBarrier<PIPE_V>();
    Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_);
    PipeBarrier<PIPE_V>();
    // expertIdxOut now holds the group index of each selected position.
    Cast(expertIdxOut, topKExpertIdFp32, RoundMode::CAST_TRUNC, k_);
    PipeBarrier<PIPE_V>();
    // Per-group padding width times group index = accumulated padding offset.
    Muls(expertIdxOut, expertIdxOut, static_cast<int32_t>(perGroupExpertCountAlign_ - perGroupExpertCount_), k_);
    PipeBarrier<PIPE_V>();
    // Subtract the padding offset to obtain the real expert id.
    Sub(expertIdxOut, topKExpertId, expertIdxOut, k_);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
|
||||
|
||||
// Writes one row of results to global memory: k_ expert scores (dtype T) and
// k_ expert ids (int32).  DataCopyPad handles the possibly non-block-aligned
// k_ * sizeof(...) byte lengths.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyOut(int64_t row)
{
    LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
    // Reuse the same params with the int32 byte length for the id copy.
    dataCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
    yOutQueue_.FreeTensor(yOutTensor);
    expertIdxOutQueue_.FreeTensor(expertIdxOut);
}
|
||||
|
||||
// One-time per-core setup: caches tiling parameters, binds the GM tensors to
// this core's row range, and allocates all UB queues/buffers.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                        GM_ADDR out, GM_ADDR workspace,
                                                        const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    k_ = tilingData_->k;
    kGroup_ = tilingData_->kGroup;
    groupCount_ = tilingData_->groupCount;
    // Round the group count up to a full Sort32 repeat.
    groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    perGroupExpertCount_ = tilingData_->perGroupExpertCount;
    perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    groupSelectMode_ = tilingData_->groupSelectMode;

    expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float));
    kAlign_ = Align(k_, sizeof(float));

    // True when the per-group expert count already matches its aligned size
    // (no padding lanes inside each group).
    isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_;

    // init input gm buf — each core starts at its own row offset.
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);

    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);

    // init que.  The (sizeof(float) / sizeof(T)) factor doubles the buffer
    // for 16-bit inputs so the raw data and its fp32 cast can coexist.
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));

    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));

    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));

    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
    // Sorted (score, index) pairs: one float value + one uint32 index each.
    pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t)));

    pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO);
    pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t));
    // Shared scratch for sort/merge/gather stages.
    // NOTE(review): magic multiplier 10 — confirm it covers the worst-case
    // transient (MrgSort) for all supported shapes.
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10);
}
|
||||
|
||||
// Per-core main loop: for every row owned by this core, normalise the gate
// scores, optionally dump them, run the grouped top-k selection pipeline and
// write the k (score, expert-id) results back to GM.
template <typename T>
__aicore__ inline void MoeGatingTopKGenerlized<T>::Process()
{
    CopyInBiasAndInitExpertId();
    for (int64_t row = 0; row < curCoreRowCount_; row++) {
        CopyInX(row);
        ComputeX();
        if (tilingData_->outFlag) {
            CopuOutXNorm(row); // name keeps the original "Copu" typo
        }
        SortInGroup();
        SelectTopKGroupIndex();
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CumputeActualTopKExpertId();
        CopyOut(row);
    }
}
|
||||
} // namespace MoeGatingTopK
|
||||
#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
||||
338
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h
Normal file
338
csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h
Normal file
@@ -0,0 +1,338 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_gating_top_k_without_group.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
||||
#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
||||
#include "kernel_operator.h"
|
||||
#include "common.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace MoeGatingTopK {
|
||||
using namespace AscendC;
|
||||
|
||||
// Top-k gating kernel variant without expert grouping: all experts compete
// directly.  Scores are sigmoid- or softmax-normalised, optionally biased for
// ranking, sorted once, and the k best (score, id) pairs are emitted per row.
template <typename T>
class MoeGatingTopKWithoutGroup {
public:
    __aicore__ inline MoeGatingTopKWithoutGroup(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
                                const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInBiasAndInitExpertId();
    __aicore__ inline void CopyInX(int64_t progress);
    __aicore__ inline void ComputeX();
    // "Copu" typo kept — referenced by Process().
    __aicore__ inline void CopuOutXNorm(int64_t row);
    __aicore__ inline void SelectTopKExpertIdx();
    __aicore__ inline void SelectTopKExpertScore();
    __aicore__ inline void CopyOut(int64_t row);

private:
    TPipe *pipe_;                                     // owning TPipe, supplied by Init()
    TQue<QuePosition::VECIN, 1> xInQueue_;            // one row of gate logits
    TQue<QuePosition::VECOUT, 1> yOutQueue_;          // k output scores
    TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;  // k output expert ids
    TQue<QuePosition::VECOUT, 1> outOutQueue_;        // full normalised-score row (optional dump)

    TBuf<TPosition::VECCALC> biasBuf_;           // Store input bias (fp32)
    TBuf<TPosition::VECCALC> expertIdBuf_;       // Expert ID ramp 0..expertCount_-1 (sort payload)
    TBuf<TPosition::VECCALC> xNormWithBiasBuf_;  // Store value after adding bias (ranking key)
    TBuf<TPosition::VECCALC> xNormBuf_;          // Store value after computing sigmoid or softmax
    TBuf<TPosition::VECCALC> topKExpertIdBuf_;   // positions of the k best experts
    TBuf<TPosition::VECCALC> calcTmpBuf_;        // shared scratch for sort/reduce stages

    GlobalTensor<T> xGm_;                // input logits [rows, expertCount_]
    GlobalTensor<T> biasGm_;             // optional bias [expertCount_]
    GlobalTensor<T> yGm_;                // output scores [rows, k_]
    GlobalTensor<int32_t> expertIdxGm_;  // output expert ids [rows, k_]
    GlobalTensor<float> outGm_;          // optional normalised-score dump [rows, expertCount_]

    int64_t blockIdx_ = 0;          // this core's index
    int64_t perCoreRowCount_ = 0;   // rows per core (all but possibly the last)
    int64_t curCoreRowCount_ = 0;   // rows this core actually processes
    int64_t expertCount_ = 0;       // experts per row
    bool addBias_ = false;          // add bias before ranking
    bool outFlag_ = false;          // also dump the normalised scores
    int64_t k_ = 0;                 // experts to select per row
    int64_t renorm_ = 0;            // renormalise selected scores (softmax mode)
    int64_t normType_ = 0;          // 0: softmax, 1: sigmoid
    int64_t expertCountAlign_ = 0;  // expertCount_ rounded up to a Sort32 repeat
    const MoeGatingTopKTilingData *tilingData_;
};
|
||||
|
||||
// Loads the optional bias row into UB (casting 16-bit dtypes to fp32) and
// fills expertIdBuf_ with the identity ramp 0..expertCount_-1 that later
// serves as the sort payload.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInBiasAndInitExpertId()
{
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();
    LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if (addBias_) {
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
            // MTE2 -> V: make the copied data visible to vector ops.
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        } else {
            // 16-bit input: stage the raw data in the upper half of the
            // buffer, then cast into the fp32 lower half.
            DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
                 expertCountAlign_);
            PipeBarrier<PIPE_V>();
        }
    }
    // expertIdTensor[i] = i
    ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCount_);
}
|
||||
|
||||
// Streams one row of gate logits from GM into the input queue.  For 16-bit
// dtypes the raw data lands in the upper half of the buffer; ComputeX() later
// casts it into the fp32 lower half.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInX(int64_t row)
{
    LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
    } else {
        DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
                    dataCopyPadParams);
    }
    xInQueue_.EnQue(xInLocalTensor);
}
|
||||
|
||||
// Normalises the current row: sigmoid (normType_ == 1) or numerically stable
// softmax (normType_ == 0, max-subtracted before Exp).  The bias (if any) is
// added into xNormWithBiasBuf_ — used only as the ranking key — and the
// padded tail lanes are filled with -inf so they can never win the sort.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::ComputeX()
{
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<float> biasTensor = biasBuf_.Get<float>();

    // 16-bit input was staged in the upper half by CopyInX; widen to fp32.
    if constexpr (!IsSameType<T, float>::value) {
        Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
             expertCount_);
        PipeBarrier<PIPE_V>();
    }

    if (normType_ == 1) { // sigmoid
        LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
        Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_);
        PipeBarrier<PIPE_V>();
    } else if (normType_ == 0) { // softmax (branch was mislabelled "sigmoid")
        LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
        LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[8];
        // Stable softmax: subtract the row max before exponentiating.
        ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_)
;
        event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float maxValue = reduceValueTensor.GetValue(0);
        event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_);
        PipeBarrier<PIPE_V>();
        Exp(xNormTensor, xNormTensor, expertCount_);
        PipeBarrier<PIPE_V>();
        ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_);
        eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
        SetFlag<HardEvent::V_S>(eventIdVToS);
        WaitFlag<HardEvent::V_S>(eventIdVToS);
        float sumValue = reduceValueTensor.GetValue(0);
        eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
        SetFlag<HardEvent::S_V>(eventIdSToV);
        WaitFlag<HardEvent::S_V>(eventIdSToV);
        // Normalise by the partition sum.  sumValue >= 1 here because the
        // max-shifted row contains a 0 term, so exp(...) sums to >= 1.
        Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_);
        PipeBarrier<PIPE_V>();
    }
    // Bias participates in ranking only; the unbiased scores stay in
    // xNormTensor and are what SelectTopKExpertScore() gathers.
    if (addBias_) {
        Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_);
    } else {
        DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
    }

    // Fill the alignment tail [expertCount_, expertCountAlign_) with -inf
    // (bit pattern written via an int32 Duplicate under a lane mask) so the
    // padded lanes sort last.
    int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM;
    int duplicateIndex = expertCount_ - duplicateNum;
    if (duplicateNum > 0) {
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1);
        PipeBarrier<PIPE_V>();
    }
    xInQueue_.FreeTensor(xInLocalTensor);
}
|
||||
|
||||
// Dumps the full normalised-score row (pre-bias) to the `out` GM tensor.
// The EnQue/DeQue round-trip through outOutQueue_ provides the V -> MTE3
// dependency before the GM write.  ("Copu" typo kept — called by Process().)
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopuOutXNorm(int64_t row)
{
    LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
    LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
    DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
    outOutQueue_.EnQue<float>(outOutTensor);
    outOutTensor = outOutQueue_.DeQue<float>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(float)), 0, 0, 0};
    DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
    outOutQueue_.FreeTensor(outOutTensor);
}
|
||||
|
||||
// Sorts all (biased) scores descending with the expert ids as payload, then
// extracts the id lanes of the leading entries via GatherMask pattern 2 (the
// odd 32-bit lanes of the interleaved (score, id) pairs).
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertIdx()
{
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
    LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
    LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
    LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
    LocalTensor<float> sortedScore = calcTmpBuf_.Get<float>();
    LocalTensor<float> sortTmp = calcTmpBuf_.Get<float>()[expertCountAlign_ * CONSTANT_TWO];
    PipeBarrier<PIPE_ALL>();
    // Full descending sort; output is (score, id) pairs, 2 x 32 bits each.
    Sort<float, true>(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp,
                      expertCountAlign_ / ONE_REPEAT_SORT_NUM);

    GatherMaskParams gatherMaskParams;
    // Enough repeats to cover k (score, id) pairs.
    gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES);
    gatherMaskParams.src0BlockStride = 1;
    gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
    gatherMaskParams.src1RepeatStride = 0;

    uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering
    uint8_t src1Pattern = 2; // Built-in fixed pattern
    PipeBarrier<PIPE_V>();
    GatherMask(topKExpertId, sortedScore.template ReinterpretCast<int32_t>(), src1Pattern, false,
               static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);

    // NOTE(review): this copies expertCountAlign_ int32 elements, but both
    // topKExpertIdBuf_ and expertIdxOutQueue_ are sized for Align(k_, ...)
    // elements in Init() — confirm the copy length should not be the aligned
    // k count instead.
    DataCopy(expertIdxOut, topKExpertId, expertCountAlign_);
    expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertScore()
|
||||
{
|
||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
||||
LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
|
||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
||||
LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
|
||||
k_);
|
||||
|
||||
bool needRenorm = (normType_ == 1 ) || // Case 1: sigmoid + renorm
|
||||
(normType_ == 0 && renorm_ == 1); // Case 3: softmax + renorm
|
||||
if (needRenorm == 1) {
|
||||
LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
|
||||
LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
|
||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
||||
float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
|
||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
||||
Duplicate(tmpTensor, sumValue, k_);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(yOutTensor, yOutTensor, tmpTensor, k_);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
|
||||
|
||||
if constexpr (!IsSameType<T, float>::value) {
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
|
||||
}
|
||||
|
||||
yOutQueue_.EnQue<float>(yOutTensor);
|
||||
}
|
||||
|
||||
// Writes one row of results to global memory: k_ expert scores (dtype T) and
// k_ expert ids (int32).  DataCopyPad handles the possibly non-block-aligned
// k_ * sizeof(...) byte lengths.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyOut(int64_t row)
{
    LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
    LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
    DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
    DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
    // Reuse the same params with the int32 byte length for the id copy.
    dataCopyParams.blockLen = k_ * sizeof(int32_t);
    DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
    yOutQueue_.FreeTensor(yOutTensor);
    expertIdxOutQueue_.FreeTensor(expertIdxOut);
}
|
||||
|
||||
// One-time per-core setup: caches tiling parameters, binds the GM tensors to
// this core's row range, and allocates all UB queues/buffers.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
                                                          GM_ADDR out, GM_ADDR workspace,
                                                          const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
{
    tilingData_ = tilingData;
    pipe_ = tPipe;
    blockIdx_ = GetBlockIdx();
    perCoreRowCount_ = tilingData_->perCoreRowCount;
    // The last core takes the remainder rows.
    if (blockIdx_ == GetBlockNum() - 1) {
        curCoreRowCount_ = tilingData_->lastCoreRowCount;
    } else {
        curCoreRowCount_ = tilingData_->perCoreRowCount;
    }
    expertCount_ = tilingData_->expertCount;
    addBias_ = tilingData_->addBias == 1;
    outFlag_ = tilingData_->outFlag == 1;
    k_ = tilingData_->k;
    renorm_ = tilingData_->renorm;
    normType_ = tilingData_->normType;
    // Round the expert count up to a full Sort32 repeat.
    expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;

    // init input gm buf — each core starts at its own row offset.
    xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
    biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);

    // init output gm buf
    yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
    outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);

    // init que.  The (sizeof(float) / sizeof(T)) factor doubles the buffer
    // for 16-bit inputs so the raw data and its fp32 cast can coexist.
    pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float));
    pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t));
    pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));

    // init calc buf
    pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
    pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
    pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
    pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float))
;
    pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t));

    // init tmp buf — shared scratch; sized for the Sort() transient
    // (sorted pairs plus working space).
    pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT);
}
|
||||
|
||||
// Per-core main loop for the ungrouped variant: normalise each owned row,
// optionally dump the normalised scores, select the top-k experts and write
// the k (score, expert-id) results back to GM.
template <typename T>
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Process()
{
    CopyInBiasAndInitExpertId();
    for (int64_t row = 0; row < curCoreRowCount_; row++) {
        CopyInX(row);
        ComputeX();
        if (outFlag_) {
            CopuOutXNorm(row); // name keeps the original "Copu" typo
        }
        SelectTopKExpertIdx();
        SelectTopKExpertScore();
        CopyOut(row);
    }
}
|
||||
} // namespace MoeGatingTopK
|
||||
#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
||||
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <graph/tensor.h>
|
||||
#include "data_copy_transpose_tiling_def.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Fills a CopyTransposeTiling structure from destination/source tensor shapes
// laid out as [B, N, S, H].  `typeSize` is the element size in bytes.  All
// derived fields are plain products of the dimensions set above them.
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
                                       optiling::CopyTransposeTiling &tiling)
{
    // Fixed dimension order: [B, N, S, H].
    constexpr int64_t kDimB = 0;
    constexpr int64_t kDimN = 1;
    constexpr int64_t kDimS = 2;
    constexpr int64_t kDimH = 3;
    const std::vector<int64_t> dstDims = dstShape.GetDims();
    const std::vector<int64_t> srcDims = srcShape.GetDims();

    // Destination shape fields.
    tiling.set_dstShapeB(dstDims[kDimB]);
    tiling.set_dstShapeN(dstDims[kDimN]);
    tiling.set_dstShapeS(dstDims[kDimS]);
    tiling.set_dstShapeH(dstDims[kDimH]);
    tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());

    // Source shape fields (the H dimension is stored as srcShapeHN).
    tiling.set_srcShapeB(srcDims[kDimB]);
    tiling.set_srcShapeN(srcDims[kDimN]);
    tiling.set_srcShapeS(srcDims[kDimS]);
    tiling.set_srcShapeHN(srcDims[kDimH]);

    // Derived values consumed by the transpose kernel.
    tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
    tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
    tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
    tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
    tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
}
|
||||
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file data_copy_transpose_tiling_def.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <register/tilingdata_base.h>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Tiling payload consumed by the data-copy-transpose kernel.  All fields are
// produced by GetDataCopyTransposeTiling; shapes are laid out as [B, N, S, H].
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);   // destination batch
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);   // destination head count
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);   // destination sequence length
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);  // dstShapeH / dstShapeN
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);   // destination hidden size
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);   // source batch
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);   // source head count
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);   // source sequence length
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);  // source H dimension
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);  // srcShapeHN * typeSize (bytes)
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);   // dstShapeS * dstShapeH
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);   // dstShapeN * dstShapeS
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);  // dstShapeN * srcShapeS * srcShapeN
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);  // NOTE(review): never set by GetDataCopyTransposeTiling — confirm use
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);   // dstShapeB * dstShapeH
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);    // NOTE(review): never set by GetDataCopyTransposeTiling — confirm use
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
||||
|
||||
} // namespace optiling
|
||||
56
csrc/moe_gating_top_k/tiling_base/error_log.h
Normal file
56
csrc/moe_gating_top_k/tiling_base/error_log.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

#include <cstdio>  // printf used by the logging macros below
#include <string>
#include "toolchain/slog.h"

// Info/debug logging is compiled out in this build.
#define OP_LOGI(opname, ...)

// BUG FIX: the previous definitions passed the caller's message arguments to
// the same printf whose format string was only "[LEVEL][%s] ", so the message
// was never printed and its arguments were silently discarded.  Print the
// prefix first, then the caller's format string with its arguments.  The ""
// literal keeps the second printf well-formed when no message is supplied; it
// requires the first variadic argument to be a string literal, as at all
// visible call sites.
#define OP_LOGW(opname, ...)              \
    do {                                  \
        printf("[WARN][%s] ", (opname));  \
        printf("" __VA_ARGS__);           \
        printf("\n");                     \
    } while (0)

// "[ERRORx]" prefix preserved from the original to keep log output
// distinguishable from OP_LOGE.
#define OP_LOGE_WITHOUT_REPORT(opname, ...)  \
    do {                                     \
        printf("[ERRORx][%s] ", (opname));   \
        printf("" __VA_ARGS__);              \
        printf("\n");                        \
    } while (0)

#define OP_LOGE(opname, ...)               \
    do {                                   \
        printf("[ERROR][%s] ", (opname));  \
        printf("" __VA_ARGS__);            \
        printf("\n");                      \
    } while (0)

#define OP_LOGD(opname, ...)

namespace optiling {

// Tiling error-report helper; forwards to the non-reporting logger (this
// build has no separate error-report channel).
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...)   \
    do {                                                         \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)

// If `cond` holds, evaluate `log_func` and then execute `expr` (typically a
// return statement).
#define OP_CHECK_IF(cond, log_func, expr) \
    do {                                  \
        if (cond) {                       \
            log_func;                     \
            expr;                         \
        }                                 \
    } while (0)

// Fails the enclosing tiling function with GRAPH_FAILED when `ptr` is null.
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr)                 \
    do {                                                         \
        if ((ptr) == nullptr) {                                  \
            OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
            return ge::GRAPH_FAILED;                             \
        }                                                        \
    } while (0)

}  // namespace optiling

#endif  // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
256
csrc/moe_gating_top_k/tiling_base/tiling_base.h
Normal file
256
csrc/moe_gating_top_k/tiling_base/tiling_base.h
Normal file
@@ -0,0 +1,256 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <exe_graph/runtime/tiling_context.h>
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "error_log.h"
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
#define ASCENDC_EXTERN_C extern "C"
|
||||
#else
|
||||
#define ASCENDC_EXTERN_C
|
||||
#endif
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
|
||||
struct AiCoreParams {
|
||||
uint64_t ubSize = 0;
|
||||
uint64_t blockDim = 0;
|
||||
uint64_t aicNum = 0;
|
||||
uint64_t l1Size = 0;
|
||||
uint64_t l0aSize = 0;
|
||||
uint64_t l0bSize = 0;
|
||||
uint64_t l0cSize = 0;
|
||||
};
|
||||
|
||||
struct CompileInfoCommon {
|
||||
uint32_t aivNum;
|
||||
uint32_t aicNum;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
uint64_t l2CacheSize;
|
||||
int64_t coreNum;
|
||||
int32_t socVersion;
|
||||
uint32_t rsvd;
|
||||
};
|
||||
|
||||
struct FlashAttentionScoreGradCompileInfo {
|
||||
uint32_t aivNum;
|
||||
uint32_t aicNum;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
uint64_t l2CacheSize;
|
||||
int64_t coreNum;
|
||||
platform_ascendc::SocVersion socVersion;
|
||||
};
|
||||
|
||||
// FlashAttention compile info. NOTE(review): field-for-field identical to
// CompileInfoCommon above -- kept as a separate type, presumably for ABI /
// naming stability; confirm before attempting to merge the two.
struct FACompileInfoCommon {
    uint32_t aivNum;       // AI vector core count (inferred from name)
    uint32_t aicNum;       // AI cube core count (inferred from name)
    uint64_t ubSize;       // unified buffer capacity
    uint64_t l1Size;       // L1 buffer capacity
    uint64_t l0aSize;      // L0A buffer capacity
    uint64_t l0bSize;      // L0B buffer capacity
    uint64_t l0cSize;      // L0C buffer capacity
    uint64_t l2CacheSize;  // L2 cache capacity
    int64_t coreNum;       // total core count
    int32_t socVersion;    // SoC version as a plain int
    uint32_t rsvd;         // reserved / padding
};
|
||||
|
||||
// Base class for the staged tiling workflow. Concrete tiling strategies
// override the pure-virtual hooks below; DoTiling() drives them in a fixed
// order and reports whether this strategy handled the case.
class TilingBaseClass {
public:
    // Holds (does not own) the tiling context supplied by the runtime.
    explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
    {}

    virtual ~TilingBaseClass() = default;

    // Tiling execution framework. Return values:
    // 1. GRAPH_SUCCESS: Success, and no need to continue executing subsequent Tiling class implementations
    // 2. GRAPH_FAILED: Failure, abort the entire Tiling process
    // 3. GRAPH_PARAM_INVALID: This class does not support, need to continue executing other Tiling class implementations
    ge::graphStatus DoTiling()
    {
        // Stage order: shapes/attrs -> platform -> capability gate ->
        // op tiling -> lib API tiling -> workspace -> post-processing.
        // Any stage failure aborts and its status is propagated unchanged.
        auto ret = GetShapeAttrsInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetPlatformInfo();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        // A template that cannot handle this case returns PARAM_INVALID so
        // the registry can try the next registered template.
        if (!IsCapable()) {
            return ge::GRAPH_PARAM_INVALID;
        }
        ret = DoOpTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = DoLibApiTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = GetWorkspaceSize();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        ret = PostTiling();
        if (ret != ge::GRAPH_SUCCESS) {
            return ret;
        }
        // Only a fully successful pipeline publishes the tiling key.
        context_->SetTilingKey(GetTilingKey());
        DumpTilingInfo();
        return ge::GRAPH_SUCCESS;
    }

    // Update context so a tiling object can be reused across invocations.
    virtual void Reset(gert::TilingContext* context)
    {
        context_ = context;
    }

protected:
    // Whether this tiling template supports the current inputs.
    virtual bool IsCapable() = 0;
    // 1. Get platform information such as CoreNum, UB/L1/L0C resource sizes
    virtual ge::graphStatus GetPlatformInfo() = 0;
    // 2. Get INPUT/OUTPUT/ATTR information
    virtual ge::graphStatus GetShapeAttrsInfo() = 0;
    // 3. Calculate data splitting TilingData
    virtual ge::graphStatus DoOpTiling() = 0;
    // 4. Calculate high-level API TilingData
    virtual ge::graphStatus DoLibApiTiling() = 0;
    // 5. Calculate TilingKey
    [[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
    // 6. Calculate Workspace size
    virtual ge::graphStatus GetWorkspaceSize() = 0;
    // 7. Save Tiling data
    virtual ge::graphStatus PostTiling() = 0;
    // 8. Dump Tiling data. No-op unless debug-level op logging is enabled.
    virtual void DumpTilingInfo()
    {
        int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
        if (enable != 1) {
            return;
        }
        // Raw tiling data is dumped as a stream of uint32 words.
        auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
        auto bufLen = context_->GetRawTilingData()->GetDataSize();
        std::ostringstream oss;
        oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
            << ", content:";
        for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
            oss << *(buf + i) << ",";
            // Flush in chunks so each log line stays within the logger's
            // limit. Split according to 640 to avoid truncation.
            if (oss.str().length() > 640) {
                OP_LOGD(context_, "%s", oss.str().c_str());
                oss.str("");
            }
        }
        OP_LOGD(context_, "%s", oss.str().c_str());
    }

    // Converts a vector-core slice count into a scheduler block dim by
    // dividing (rounding up) by the aiv/aic core ratio. Falls back to
    // sliceNum itself when the ratio cannot be formed.
    static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
    {
        uint32_t ration; // aiv-to-aic core ratio ("ration" is a pre-existing misspelling of "ratio")
        if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
            return sliceNum;
        }
        ration = aivCoreNum / aicCoreNum;
        return (sliceNum + (ration - 1)) / ration; // ceil(sliceNum / ration)
    }

    // Formats a shape as "[d0, d1, ...]" for log output; "[]" for rank 0.
    template <typename T>
    [[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
    {
        std::ostringstream oss;
        oss << "[";
        if (shape.GetDimNum() > 0) {
            for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
                oss << shape.GetDim(i) << ", ";
            }
            oss << shape.GetDim(shape.GetDimNum() - 1);
        }
        oss << "]";
        return oss.str();
    }

    // Formats dtype/shape/format info of one tensor; "nil " when either
    // descriptor is absent.
    [[nodiscard]] std::string GetTensorDebugStr(
        const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
    {
        if (shape == nullptr || tensor == nullptr) {
            return "nil ";
        }
        std::ostringstream oss;
        oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
        oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
        oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
        oss << "(format: "
            << ge::TypeUtils::FormatToSerialString(
                   static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
            << "),";
        oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
        return oss.str();
    }

    // Renders every input/output tensor of the current compute node.
    [[nodiscard]] std::string GetTilingContextDebugStr()
    {
        std::ostringstream oss;
        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
            oss << "input" << i << ": ";
            oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
        }

        for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
            oss << "output" << i << ": ";
            oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
        }
        return oss.str();
    }

    // Renders the raw tiling buffer as comma-separated int32 words.
    [[nodiscard]] std::string GetTilingDataDebugStr() const
    {
        auto rawTilingData = context_->GetRawTilingData();
        auto rawTilingDataSize = rawTilingData->GetDataSize();
        auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
        size_t len = rawTilingDataSize / sizeof(int32_t);
        std::ostringstream oss;
        for (size_t i = 0; i < len; i++) {
            oss << data[i] << ", ";
        }
        return oss.str();
    }

protected:
    gert::TilingContext* context_ = nullptr;                                      // not owned
    std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr}; // platform query helper
    uint32_t blockDim_{0};       // computed block dim
    uint64_t workspaceSize_{0};  // computed workspace size (presumably bytes)
    uint64_t tilingKey_{0};      // computed tiling key
    AiCoreParams aicoreParams_;  // cached AI-core resource sizes
};
|
||||
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
63
csrc/moe_gating_top_k/tiling_base/tiling_key.h
Normal file
63
csrc/moe_gating_top_k/tiling_base/tiling_key.h
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_key.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace Ops {
namespace Transformer {
namespace OpTiling {

// Terminal case of the digit-packing fold: an empty id list contributes 0.
constexpr uint64_t RecursiveSum()
{
    return 0;
}

constexpr uint64_t kBase = 10; // Base-10 carry base

// Packs a sequence of template ids into one integer, one decimal digit per
// id: the first argument lands in the lowest digit and every following
// argument is shifted one decimal place higher.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T headId, Args... tailIds)
{
    return static_cast<uint64_t>(headId) + kBase * RecursiveSum(tailIds...);
}

// TilingKey layout for FlashAttentionScore/FlashAttentionScoreGrad (decimal
// digits, low to high): Ub0, Ub1, Block, DataType, Format, Sparse.
//   Ub0/Ub1:  axes split inside UB (AxisEnum; AXIS_NONE when a split slot is
//             unused; at most two UB axes) - one digit each
//   Block:    axis used for UB multi-core splitting (AxisEnum) - one digit
//   DataType: supported input/output dtype (SupportedDtype) - one digit
//   Format:   supported layout (InputLayout) - one digit
//   Sparse:   sparse capability (SparseCapability) - one digit
// Other specialized scenarios may define their own bit fields and values.
// usage: get tilingKey from inputed types
//   uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
//                                                     SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)

// Fixed offset lifting every generated key into a reserved numeric range.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19

// Builds the final tiling key: the reserved offset plus the packed digits.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return RecursiveSum(templateIds...) + TILINGKEYOFFSET;
}

// Convenience wrapper that spells out the enum scopes of the six fields.
// usage: uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
     SparseEnum::sparse))

} // namespace OpTiling
} // namespace Transformer
} // namespace Ops
|
||||
351
csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h
Normal file
351
csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h
Normal file
@@ -0,0 +1,351 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_templates_registry.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "exe_graph/runtime/tiling_context.h"
|
||||
#include "tiling_base.h"
|
||||
#include "error_log.h"
|
||||
|
||||
namespace Ops {
|
||||
namespace Transformer {
|
||||
namespace OpTiling {
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
|
||||
{
|
||||
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
||||
}
|
||||
|
||||
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
|
||||
|
||||
// Ordered collection of tiling template factories for one operator, keyed by
// priority. std::map keeps entries sorted ascending, so iteration visits the
// highest-priority (smallest value) template first.
class TilingCases {
public:
    // op_type is kept only for log messages.
    explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
    {}

    // Registers factory TILING_CLASS<T> under `priority`. A duplicate
    // priority is rejected with an error log and leaves the map unchanged.
    template <typename T>
    void AddTiling(int32_t priority)
    {
        OP_CHECK_IF(
            cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
        cases_[priority] = TILING_CLASS<T>;
        OP_CHECK_IF(
            cases_[priority] == nullptr,
            OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
    }

    // Read access to the priority -> factory map.
    const std::map<int32_t, TilingClassCase>& GetTilingCases()
    {
        return cases_;
    }

private:
    std::map<int32_t, TilingClassCase> cases_; // priority -> factory
    const std::string op_type_;                // operator name, for logging
};
|
||||
|
||||
// --------------------------------Interface with soc version --------------------------------
|
||||
// Registry keyed by (SoC version, op name). Stores tiling template factories
// per op per SoC and drives template selection at tiling time.
class TilingRegistryNew {
public:
    TilingRegistryNew() = default;

#ifdef ASCENDC_OP_TEST
    // Defined out-of-line under test builds so the singleton can be replaced.
    static TilingRegistryNew& GetInstance();
#else
    // Meyers singleton: one registry per process.
    static TilingRegistryNew& GetInstance()
    {
        static TilingRegistryNew registry_impl_;
        return registry_impl_;
    }
#endif

    // Returns (creating on first use) the TilingCases bucket for
    // (soc_version, op_type); nullptr only if allocation failed.
    std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
    {
        auto soc_iter = registry_map_.find(soc_version);
        if (soc_iter == registry_map_.end()) {
            // First registration for this SoC version: build its op map.
            std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
            op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            registry_map_[soc_version] = op_type_map;
        } else {
            if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
                soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
            }
        }

        OP_CHECK_IF(
            registry_map_[soc_version][op_type] == nullptr,
            OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
        return registry_map_[soc_version][op_type];
    }

    // Resolves the SoC version (from the platform when available, otherwise
    // from compile info), then tries every registered template in ascending
    // priority order; the first one not reporting GRAPH_PARAM_INVALID wins.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context)
    {
        int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
        const char* op_type = context->GetNodeType();
        fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
        if (platformInfoPtr == nullptr) {
            // No platform handle: fall back to the compile-info POD.
            auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
            OP_CHECK_IF(
                compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
            soc_version = compileInfoPtr->socVersion;
            OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
        } else {
            auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
            soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
            OP_LOGD(context, "soc version is %d", soc_version);
            if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
                OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
                return ge::GRAPH_FAILED;
            }
        }
        auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
        for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
            auto tilingTemplate = it->second(context);
            if (tilingTemplate != nullptr) {
                ge::graphStatus status = tilingTemplate->DoTiling();
                if (status != ge::GRAPH_PARAM_INVALID) {
                    // SUCCESS and FAILED both end the search here.
                    OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
                    return status;
                }
                OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
            }
        }
        OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
        return ge::GRAPH_FAILED;
    }

    // Variant that tries only the given priorities, in the given order.
    // Note: unlike the overload above, only GRAPH_SUCCESS stops the loop; a
    // GRAPH_FAILED template is skipped and later priorities are still tried.
    ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
    {
        int32_t soc_version;
        const char* op_type = context->GetNodeType();
        auto platformInfoPtr = context->GetPlatformInfo();
        if (platformInfoPtr == nullptr) {
            auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
            OP_CHECK_IF(
                compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
            soc_version = compileInfoPtr->socVersion;
            OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
        } else {
            auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
            soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
            OP_LOGD(context, "soc version is %d", soc_version);
        }

        auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
        for (auto priority_id : priorities) {
            // Unknown priorities are silently skipped.
            auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
            if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
                auto templateFunc = tilingCaseIter->second(context);
                if (templateFunc != nullptr) {
                    ge::graphStatus status = templateFunc->DoTiling();
                    if (status == ge::GRAPH_SUCCESS) {
                        OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
                        return status;
                    }
                    OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
                }
            }
        }
        return ge::GRAPH_FAILED;
    }

    // Returns the priority -> factory map for (soc_version, op_type), or the
    // shared empty map when either key is unknown.
    const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
    {
        auto soc_iter = registry_map_.find(soc_version);
        OP_CHECK_IF(
            soc_iter == registry_map_.end(),
            OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
            return empty_tiling_case_);
        auto op_iter = soc_iter->second.find(op_type);
        OP_CHECK_IF(
            op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
            return empty_tiling_case_);
        return op_iter->second->GetTilingCases();
    }

private:
    std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
    const std::map<int32_t, TilingClassCase> empty_tiling_case_{}; // returned for unknown keys
};
|
||||
|
||||
class RegisterNew {
|
||||
public:
|
||||
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
|
||||
{}
|
||||
|
||||
template <typename T>
|
||||
RegisterNew& tiling(int32_t priority, int32_t soc_version)
|
||||
{
|
||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
|
||||
{
|
||||
for (int32_t soc_version : soc_versions) {
|
||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
|
||||
return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
|
||||
// --------------------------------Interface without soc version --------------------------------
|
||||
class TilingRegistry {
|
||||
public:
|
||||
TilingRegistry() = default;
|
||||
|
||||
#ifdef ASCENDC_OP_TEST
|
||||
static TilingRegistry& GetInstance();
|
||||
#else
|
||||
static TilingRegistry& GetInstance()
|
||||
{
|
||||
static TilingRegistry registry_impl_;
|
||||
return registry_impl_;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
|
||||
{
|
||||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
||||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
||||
}
|
||||
OP_CHECK_IF(
|
||||
registry_map_[op_type] == nullptr,
|
||||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
||||
return registry_map_[op_type];
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
||||
{
|
||||
const char* op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
||||
auto tilingTemplate = it->second(context);
|
||||
if (tilingTemplate != nullptr) {
|
||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
||||
}
|
||||
}
|
||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
||||
{
|
||||
const char* op_type = context->GetNodeType();
|
||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
||||
for (auto priorityId : priorities) {
|
||||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
||||
if (templateFunc != nullptr) {
|
||||
ge::graphStatus status = templateFunc->DoTiling();
|
||||
if (status == ge::GRAPH_SUCCESS) {
|
||||
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
|
||||
return status;
|
||||
}
|
||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
||||
OP_LOGD(context, "Do op tiling failed");
|
||||
return status;
|
||||
}
|
||||
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
|
||||
}
|
||||
}
|
||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
|
||||
{
|
||||
OP_CHECK_IF(
|
||||
registry_map_.find(op_type) == registry_map_.end(),
|
||||
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
|
||||
return registry_map_[op_type]->GetTilingCases();
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
|
||||
};
|
||||
|
||||
class Register {
|
||||
public:
|
||||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
||||
{}
|
||||
|
||||
template <typename T>
|
||||
Register& tiling(int32_t priority)
|
||||
{
|
||||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
||||
OP_CHECK_IF(
|
||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
||||
tilingCases->AddTiling<T>(priority);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string op_type_;
|
||||
};
|
||||
} // namespace OpTiling
|
||||
} // namespace Transformer
|
||||
} // namespace Ops
|
||||
|
||||
// op_type: operator name, class_name: registered tiling class, soc_version: chip version number
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
// NOTE(review): "priority_register" is a single token, not "priority ## _register", so the
// generated static's name does not vary with priority; registering the same op/class pair at
// two priorities in one TU would collide -- confirm whether that combination ever occurs.
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)

// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// NOTE(review): "op_type_" here is NOT the macro parameter (trailing underscore), so the pasted
// variable name varies only with class_name; also `Register(op_type)` passes op_type verbatim
// (no #), so callers must supply a string. Both look suspicious next to the sibling macros --
// verify against actual call sites before changing.
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
        Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)

// op_type: operator name, class_name: registered tiling class
// soc_version: SOC version, used to distinguish different SOCs
// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
        Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)

// op_type: operator name, class_name: registered tiling class
// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected
// Replaces REGISTER_TILING_TEMPLATE, if op_type is a string constant, remove the quotes
// (This variant pastes every component, including priority, so generated names are unique.)
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
    [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
    static Ops::Transformer::OpTiling::Register \
        __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
            Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)
|
||||
139
csrc/moe_gating_top_k/tiling_base/tiling_type.h
Normal file
139
csrc/moe_gating_top_k/tiling_base/tiling_type.h
Normal file
@@ -0,0 +1,139 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_type.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
// Axis identifiers packed into tiling keys (one decimal digit each).
// NOTE(review): names follow FlashAttention conventions (B/N2/G/S1/S2/D); the
// exact axis semantics are defined by the ops using them -- confirm before
// relying on a specific meaning.
enum class AxisEnum {
    B = 0,
    N2 = 1,
    G = 2,
    S1 = 3,
    S2 = 4,
    D = 5,
    NONE = 9, // no axis / split slot unused
};

// Input/output data types a tiling key can advertise.
enum class DtypeEnum {
    FLOAT16 = 0,
    FLOAT32 = 1,
    BFLOAT16 = 2,
    FLOAT16_PRECISION = 3, // presumably fp16 with higher-precision accumulation -- confirm
};

// Buffering strategy selector.
enum class PerformanceOrientedEnum {
    BIG_BUFFER = 1,
    BIG_DOUBLE_BUFFER = 2,
};

// Matmul high-level-API configuration selector.
enum class MatmulConfig {
    NULL_CONFIG = 0,
    NORMAL_CONFIG = 1,
    MDL_CONFIG = 2
};

// Whether a PSE input is present.
enum class PseConfig {
    NO_PSE = 0,
    EXIST_PSE = 1
};

// Whether an attention-mask input is present.
enum class AttenMaskConfig {
    NO_ATTEN_MASK = 0,
    EXIST_ATTEN_MASK = 1
};

// Whether dropout is applied.
enum class DropOutConfig {
    NO_DROP_OUT = 0,
    EXIST_DROP_OUT = 1
};

// Storage format of cube (matmul) operands.
enum class CubeFormatEnum {
    ND = 0,
    NZ = 1
};
// Tensor layouts a tiling key can advertise.
enum class LayoutEnum {
    BSND = 0,
    SBND = 1,
    BNSD = 2,
    TND = 3,
    NTD_TND = 4
};

// Where cube inputs are sourced from.
enum class CubeInputSourceEnum {
    GM = 0, // global memory
    L1 = 1  // L1 buffer
};

// Generic on/off switch.
enum class OptionEnum {
    DISABLE = 0,
    ENABLE = 1
};

// Sparse attention patterns a tiling key can advertise.
enum class SparseEnum {
    ALL = 0,
    NONE = 1,
    ANY = 2,
    CAUSAL = 3,
    BAND = 4,
    PREFIX = 5,
    BAND_COMPRESS = 6,
    RIGHT_DOWN_CAUSAL = 7,
    RIGHT_DOWN_CAUSAL_BAND = 8,
    BAND_LEFT_UP_CAUSAL = 9
};
|
||||
|
||||
// Terminal case of the digit-packing fold: an empty id list contributes 0.
constexpr uint64_t RecursiveSum()
{
    return 0;
}

// Decimal base used to shift each successive template id one digit higher.
constexpr int64_t base10Multiplier = 10;

// Packs a sequence of template ids into one integer, one decimal digit per
// id: the first argument occupies the lowest digit and every following
// argument is shifted one decimal place higher.
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T firstId, Args... restIds)
{
    return static_cast<uint64_t>(firstId) + base10Multiplier * RecursiveSum(restIds...);
}

// TilingKey layout for FlashAttentionScore/FlashAttentionScoreGrad (decimal
// digits, low to high): Ub0, Ub1, Block, DataType, Format, Sparse.
//   Ub0/Ub1:  axes split inside UB (AxisEnum; AXIS_NONE when a split slot is
//             unused; at most two UB axes) - one digit each
//   Block:    axis used for UB multi-core splitting (AxisEnum) - one digit
//   DataType: supported input/output dtype (SupportedDtype) - one digit
//   Format:   supported layout (InputLayout) - one digit
//   Sparse:   sparse capability (SparseCapability) - one digit
// Other specialized scenarios may define their own bit fields and values.
// usage: get tilingKey from inputed types
//   uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
//                                                     SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)

// Fixed offset lifting every generated key into a reserved numeric range.
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19

// Builds the final tiling key: the reserved offset plus the packed digits.
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
{
    return RecursiveSum(templateIds...) + TILINGKEYOFFSET;
}

// Convenience wrapper that spells out the enum scopes of the six fields.
// usage: uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
     SparseEnum::sparse))
|
||||
|
||||
} // namespace optiling
|
||||
30
csrc/moe_gating_top_k/tiling_base/tiling_util.h
Normal file
30
csrc/moe_gating_top_k/tiling_base/tiling_util.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file tiling_util.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "register/op_impl_registry.h"
|
||||
|
||||
namespace Ops {
namespace Transformer {
namespace OpTiling {
// Declarations only; implementations live in a separate translation unit.

// Whether the SoC targeted by this tiling-parse context is a regbase
// architecture -- exact criterion defined by the out-of-line implementation.
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);

// Overload taking the tiling context instead of the parse context.
bool IsRegbaseSocVersion(const gert::TilingContext* context);

// Returns a shape guaranteed not to be scalar; presumably maps a rank-0
// shape to an equivalent rank-1 view -- confirm in the implementation.
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
} // namespace OpTiling
} // namespace Transformer
} // namespace Ops
|
||||
@@ -1219,10 +1219,65 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
||||
return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
|
||||
const at::Tensor& x,
|
||||
int64_t k,
|
||||
int64_t k_group,
|
||||
int64_t group_count,
|
||||
int64_t group_select_mode,
|
||||
int64_t renorm,
|
||||
int64_t norm_type,
|
||||
bool out_flag,
|
||||
double routed_scaling_factor,
|
||||
double eps,
|
||||
const c10::optional<at::Tensor>& bias_opt
|
||||
)
|
||||
{
|
||||
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
|
||||
TORCH_CHECK(
|
||||
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
|
||||
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
|
||||
x.scalar_type());
|
||||
|
||||
auto x_size = x.sizes();
|
||||
auto rows = x_size[0];
|
||||
auto expert_num = x_size[1];
|
||||
const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
|
||||
if (bias.defined()) {
|
||||
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
|
||||
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
|
||||
auto bias_size = bias.sizes();
|
||||
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
|
||||
}
|
||||
at::Tensor y = at::empty({rows, k}, x.options());
|
||||
at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
|
||||
at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeGatingTopK,
|
||||
x,
|
||||
bias,
|
||||
k,
|
||||
k_group,
|
||||
group_count,
|
||||
group_select_mode,
|
||||
renorm,
|
||||
norm_type,
|
||||
out_flag,
|
||||
routed_scaling_factor,
|
||||
eps,
|
||||
y,
|
||||
expert_idx,
|
||||
out
|
||||
);
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
{
|
||||
|
||||
// vLLM-Ascend custom ops
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
|
||||
@@ -1358,6 +1413,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"num_ranks) -> Tensor");
|
||||
ops.impl("combine_prefill", torch::kPrivateUse1,
|
||||
&vllm_ascend::combine_prefill);
|
||||
|
||||
ops.def(
|
||||
"npu_moe_init_routing_custom(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, "
|
||||
" int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, "
|
||||
@@ -1365,4 +1421,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)"
|
||||
);
|
||||
ops.impl("npu_moe_init_routing_custom", torch::kPrivateUse1, &vllm_ascend::npu_moe_init_routing_custom);
|
||||
// vLLM-Ascend custom ops
|
||||
ops.def(
|
||||
"moe_gating_top_k(Tensor x, "
|
||||
"int k, "
|
||||
"int k_group, "
|
||||
"int group_count, "
|
||||
"int group_select_mode, "
|
||||
"int renorm, "
|
||||
"int norm_type, "
|
||||
"bool out_flag, "
|
||||
"float routed_scaling_factor, "
|
||||
"float eps,"
|
||||
"Tensor? bias_opt=None)"
|
||||
|
||||
"-> (Tensor y ,Tensor expert_idx, Tensor out)"
|
||||
);
|
||||
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
|
||||
}
|
||||
|
||||
@@ -366,7 +366,43 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
||||
at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));
|
||||
return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale};
|
||||
}
|
||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
|
||||
const at::Tensor& x,
|
||||
int64_t k,
|
||||
int64_t k_group,
|
||||
int64_t group_count,
|
||||
int64_t group_select_mode,
|
||||
int64_t renorm,
|
||||
int64_t norm_type,
|
||||
bool out_flag,
|
||||
double routed_scaling_factor,
|
||||
double eps,
|
||||
const c10::optional<at::Tensor>& bias_opt
|
||||
|
||||
)
|
||||
{
|
||||
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
|
||||
TORCH_CHECK(
|
||||
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
|
||||
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
|
||||
x.scalar_type());
|
||||
|
||||
auto x_size = x.sizes();
|
||||
auto rows = x_size[0];
|
||||
auto expert_num = x_size[1];
|
||||
const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
|
||||
if (bias.defined()) {
|
||||
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
|
||||
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
|
||||
auto bias_size = bias.sizes();
|
||||
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
|
||||
}
|
||||
at::Tensor y = at::empty({rows, k}, x.options());
|
||||
at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
|
||||
at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -374,6 +410,7 @@ namespace {
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// the custom kernel been captured into aclgraph
|
||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
|
||||
// Rotary embedding meta implementation
|
||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||
// Masked input and mask meta implementation
|
||||
@@ -402,5 +439,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
|
||||
// moe_init_routing_custom
|
||||
ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta);
|
||||
// Moe_gating_top_k
|
||||
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,8 +305,8 @@ def test_select_experts(
|
||||
)
|
||||
|
||||
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
||||
hidden_states, topk, topk_group, num_expert_group, scoring_func,
|
||||
custom_routing_function)
|
||||
hidden_states, topk, renormalize, topk_group, num_expert_group,
|
||||
scoring_func, custom_routing_function)
|
||||
if not call_moe_gatingtopk and use_grouped_topk:
|
||||
mock_native_grouped_topk.assert_called_once()
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
# Fix random seed to ensure test reproducibility
|
||||
RTOL_TOLERANCE = 1e-5
|
||||
ATOL_TOLERANCE = 1e-8
|
||||
seed = 45
|
||||
random.seed(seed)
|
||||
numpy.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def softmax_func(x, axis=None):
|
||||
"""Softmax implementation (adapted for numpy calculation)"""
|
||||
if "float16" in x.dtype.name:
|
||||
x = x.astype(numpy.float32)
|
||||
x_max = x.max(axis=axis, keepdims=True)
|
||||
x_sub = x - x_max
|
||||
y = numpy.exp(x_sub)
|
||||
x_sum = y.sum(axis=axis, keepdims=True)
|
||||
res = y / x_sum
|
||||
return res, x_max, x_sum
|
||||
|
||||
|
||||
def moe_gating_top_k_numpy_ref(x: torch.Tensor,
|
||||
k: int,
|
||||
bias: torch.Tensor | None,
|
||||
k_group: int = 1,
|
||||
group_count: int = 1,
|
||||
group_select_mode: int = 0,
|
||||
renorm: int = 0,
|
||||
norm_type: int = 0,
|
||||
y2_flag: bool = False,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
eps: float = 1e-20) -> tuple:
|
||||
"""NumPy reference implementation of MOE Gating TopK.
|
||||
|
||||
For result comparison with NPU operator, ensure the consistency
|
||||
between NPU kernel and baseline implementation.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (num_tokens, num_experts)
|
||||
k: Number of top-k experts to select
|
||||
bias: Bias tensor of shape (num_experts,) (optional)
|
||||
k_group: Number of top-k groups to select
|
||||
group_count: Number of expert groups
|
||||
group_select_mode: Group selection mode (0: max, 1: top2 sum)
|
||||
renorm: Whether to renormalize the output (0/1)
|
||||
norm_type: Normalization type (0: softmax, 1: sigmoid)
|
||||
y2_flag: Whether to output original x as y2
|
||||
routed_scaling_factor: Scaling factor for routing weights
|
||||
eps: Small epsilon to avoid division by zero
|
||||
|
||||
Returns:
|
||||
tuple: (y, indices, y2)
|
||||
- y: Top-k weights of shape (num_tokens, k)
|
||||
- indices: Top-k expert indices of shape (num_tokens, k)
|
||||
- y2: Original x if y2_flag is True, else None
|
||||
"""
|
||||
dtype = x.dtype
|
||||
if dtype != torch.float32:
|
||||
x = x.to(dtype=torch.float32)
|
||||
if bias is not None:
|
||||
bias = bias.to(dtype=torch.float32)
|
||||
|
||||
x = x.numpy()
|
||||
if bias is not None:
|
||||
bias = bias.numpy()
|
||||
|
||||
if norm_type == 0: # softmax normalization
|
||||
x, _, _ = softmax_func(x, -1)
|
||||
else: # sigmoid normalization
|
||||
x = 1 / (1 + numpy.exp(-x))
|
||||
|
||||
original_x = x
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
|
||||
if group_count > 1:
|
||||
x = x.reshape(x.shape[0], group_count, -1)
|
||||
if group_select_mode == 0:
|
||||
group_x = numpy.amax(x, axis=-1)
|
||||
else:
|
||||
group_x = numpy.partition(x, -2, axis=-1)[..., -2:].sum(axis=-1)
|
||||
indices = numpy.argsort(-group_x, axis=-1, kind='stable')[:, :k_group]
|
||||
|
||||
mask = numpy.ones((x.shape[0], group_count), dtype=bool)
|
||||
mask[numpy.arange(x.shape[0])[:, None], indices] = False
|
||||
x = numpy.where(mask[..., None], float('-inf'), x)
|
||||
x = x.reshape(x.shape[0], -1)
|
||||
|
||||
_, indices = torch.sort(torch.from_numpy(x),
|
||||
dim=-1,
|
||||
stable=True,
|
||||
descending=True)
|
||||
indices = numpy.asarray(indices[:, :k])
|
||||
|
||||
y = numpy.take_along_axis(original_x, indices, axis=1)
|
||||
if norm_type == 1 or renorm == 1:
|
||||
y /= (numpy.sum(y, axis=-1, keepdims=True) + eps)
|
||||
y *= routed_scaling_factor
|
||||
|
||||
y2 = original_x if y2_flag else None
|
||||
y = torch.tensor(y, dtype=dtype)
|
||||
return y, indices.astype(numpy.int32), y2
|
||||
|
||||
|
||||
# pytest parameterized decorators (cover all test scenarios)
|
||||
@pytest.mark.parametrize("group_select_mode", [0, 1])
|
||||
@pytest.mark.parametrize("renorm", [1])
|
||||
@pytest.mark.parametrize("norm_type", [0, 1])
|
||||
@pytest.mark.parametrize("group_count", [1, 8])
|
||||
@pytest.mark.parametrize("k_ranges", [4, 8, 12, 16, 6, 32])
|
||||
@pytest.mark.parametrize("x_dim0_range", list(range(1, 17)))
|
||||
@pytest.mark.parametrize("x_dim1_range", [256, 128, 64, 208, 192, 160])
|
||||
def test_npu_moe_gating_topk_compare(group_select_mode: int,
|
||||
renorm: int,
|
||||
norm_type: int,
|
||||
group_count: int,
|
||||
k_ranges: int,
|
||||
x_dim0_range: int,
|
||||
x_dim1_range: int,
|
||||
device: str = "npu"):
|
||||
"""Ascend NPU MOE Gating TopK operator test.
|
||||
|
||||
Compare NPU kernel results with NumPy reference implementation
|
||||
to verify the correctness of Ascend custom op.
|
||||
|
||||
Args:
|
||||
group_select_mode: Group selection mode (0: max, 1: top2 sum)
|
||||
renorm: Whether to renormalize output (fixed to 1 in test)
|
||||
norm_type: Normalization type (0: softmax, 1: sigmoid)
|
||||
group_count: Number of expert groups
|
||||
k_ranges: Number of top-k experts to select
|
||||
x_dim0_range: First dimension of input tensor (num_tokens)
|
||||
x_dim1_range: Second dimension of input tensor (num_experts)
|
||||
device: Target device (fixed to "npu" in test)
|
||||
"""
|
||||
# Simplify parameter names for better readability
|
||||
k = k_ranges
|
||||
dim0 = x_dim0_range
|
||||
dim1 = x_dim1_range
|
||||
|
||||
# Skip invalid cases: k cannot exceed num_experts per group
|
||||
if k > dim1 // group_count:
|
||||
return
|
||||
|
||||
# Construct test inputs
|
||||
x = numpy.random.uniform(-2, 2, (dim0, dim1)).astype(numpy.float32)
|
||||
bias = numpy.random.uniform(-2, 2, (dim1, )).astype(numpy.float32)
|
||||
|
||||
x_tensor = torch.tensor(x, dtype=torch.float32)
|
||||
bias_tensor = torch.tensor(bias, dtype=torch.float32)
|
||||
# Fix k_group value to avoid irreproducibility caused by random.randint
|
||||
k_group = min(1, group_count)
|
||||
out_flag = False
|
||||
routed_scaling_factor = 1.0
|
||||
eps = 1e-20
|
||||
|
||||
# Calculate NumPy reference results
|
||||
y, expert_idx, out = moe_gating_top_k_numpy_ref(
|
||||
x_tensor,
|
||||
k=k,
|
||||
bias=bias_tensor,
|
||||
k_group=k_group,
|
||||
group_count=group_count,
|
||||
group_select_mode=group_select_mode,
|
||||
renorm=renorm,
|
||||
norm_type=norm_type,
|
||||
y2_flag=out_flag,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
# Calculate NPU operator results
|
||||
y_npu, expert_idx_npu, out_npu = torch.ops._C_ascend.moe_gating_top_k(
|
||||
x_tensor.npu(),
|
||||
k=k,
|
||||
k_group=k_group,
|
||||
group_count=group_count,
|
||||
group_select_mode=group_select_mode,
|
||||
renorm=renorm,
|
||||
norm_type=norm_type,
|
||||
out_flag=out_flag,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=eps,
|
||||
bias_opt=bias_tensor.npu() if bias_tensor is not None else None,
|
||||
)
|
||||
|
||||
# Verify consistency between NPU and NumPy results
|
||||
assert numpy.allclose(y.cpu().numpy(),
|
||||
y_npu.cpu().numpy(),
|
||||
rtol=RTOL_TOLERANCE,
|
||||
atol=ATOL_TOLERANCE)
|
||||
assert numpy.allclose(expert_idx,
|
||||
expert_idx_npu.cpu().numpy(),
|
||||
rtol=RTOL_TOLERANCE,
|
||||
atol=ATOL_TOLERANCE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Execute pytest tests with verbose output
|
||||
pytest.main([__file__, "-sv"])
|
||||
@@ -17,7 +17,6 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import get_weight_prefetch_method
|
||||
|
||||
@@ -64,6 +63,7 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
|
||||
hidden_states=hidden_states,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
@@ -102,10 +102,13 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
def check_npu_moe_gating_top_k(
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
scoring_func: str = "softmax",
|
||||
custom_routing_function: Optional[Callable] = None):
|
||||
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch
|
||||
return False
|
||||
if custom_routing_function is not None:
|
||||
return False
|
||||
if scoring_func != "softmax" and scoring_func != "sigmoid":
|
||||
@@ -209,26 +212,25 @@ def _select_experts_with_fusion_ops(
|
||||
|
||||
topk_group = topk_group if topk_group is not None else 1
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
||||
renorm = int(renormalize)
|
||||
norm_type = 0 if scoring_func == "softmax" else 1
|
||||
if e_score_correction_bias is not None and \
|
||||
e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(
|
||||
router_logits.dtype)
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group,
|
||||
group_count=num_expert_group,
|
||||
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
group_select_mode=1,
|
||||
renorm=renorm,
|
||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
out_flag=False,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20))
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
eps=float(1e-20),
|
||||
bias_opt=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user