Revert "moe_gating_top_k" (#5512)
Reverts vllm-project/vllm-ascend#5271
It breaks e2e test
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
This commit is contained in:
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
|||||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||||
|
|
||||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k"
|
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom"
|
||||||
SOC_ARG="ascend910b"
|
SOC_ARG="ascend910b"
|
||||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||||
# ASCEND910C (A3) series
|
# ASCEND910C (A3) series
|
||||||
@@ -70,7 +70,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
|||||||
"dispatch_layout"
|
"dispatch_layout"
|
||||||
"notify_dispatch"
|
"notify_dispatch"
|
||||||
"moe_init_routing_custom"
|
"moe_init_routing_custom"
|
||||||
"moe_gating_top_k"
|
|
||||||
)
|
)
|
||||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||||
SOC_ARG="ascend910_93"
|
SOC_ARG="ascend910_93"
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
# ----------------------------------------------------------------------------
|
|
||||||
# This program is free software, you can redistribute it and/or modify.
|
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
# This file is a part of the CANN Open Software.
|
|
||||||
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
# See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
add_ops_compile_options(
|
|
||||||
OP_NAME MoeGatingTopK
|
|
||||||
OPTIONS --cce-auto-sync=on
|
|
||||||
-Wno-deprecated-declarations
|
|
||||||
-Werror
|
|
||||||
)
|
|
||||||
|
|
||||||
# Host 侧算子实现(aclnn)
|
|
||||||
if (BUILD_OPEN_PROJECT)
|
|
||||||
target_sources(op_host_aclnn PRIVATE
|
|
||||||
moe_gating_top_k_def.cpp
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Tiling 模块
|
|
||||||
target_sources(optiling PRIVATE
|
|
||||||
moe_gating_top_k_tiling.cpp
|
|
||||||
moe_gating_top_k_tiling_base.cpp
|
|
||||||
moe_gating_top_k_tiling_arch35.cpp
|
|
||||||
)
|
|
||||||
target_sources(opsproto PRIVATE
|
|
||||||
moe_gating_top_k_proto.cpp
|
|
||||||
moe_gating_top_k_infershape.cpp
|
|
||||||
|
|
||||||
)
|
|
||||||
target_include_directories(optiling PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include "toolchain/slog.h"
|
|
||||||
|
|
||||||
#define OP_LOGI(opname, ...)
|
|
||||||
#define OP_LOGW(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGD(opname, ...)
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
|
||||||
do { \
|
|
||||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
|
|
||||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
|
||||||
do { \
|
|
||||||
if (cond) { \
|
|
||||||
log_func; \
|
|
||||||
expr; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
|
||||||
do { \
|
|
||||||
if ((ptr) == nullptr) { \
|
|
||||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
|
||||||
return ge::GRAPH_FAILED; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
|
|
||||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
|
||||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
|
||||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file math_util.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef TILING_MATMUL_MATH_UTIL_H
|
|
||||||
#define TILING_MATMUL_MATH_UTIL_H
|
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
namespace matmul_tiling {
|
|
||||||
class MathUtil {
|
|
||||||
public:
|
|
||||||
static bool IsEqual(float leftValue, float rightValue);
|
|
||||||
template<typename T>
|
|
||||||
static auto CeilDivision(T num1, T num2) -> T
|
|
||||||
{
|
|
||||||
if (num2 == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return static_cast<T>((static_cast<int64_t>(num1) + static_cast<int64_t>(num2) - 1) /
|
|
||||||
static_cast<int64_t>(num2));
|
|
||||||
}
|
|
||||||
template<typename T>
|
|
||||||
static auto Align(T num1, T num2) -> T
|
|
||||||
{
|
|
||||||
return CeilDivision(num1, num2) * num2;
|
|
||||||
}
|
|
||||||
static int32_t AlignDown(int32_t num1, int32_t num2);
|
|
||||||
static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c);
|
|
||||||
static int32_t MapShape(int32_t shape, bool roundUpFlag = true);
|
|
||||||
static void AddFactor(std::vector<int32_t> &dimsFactors, int32_t dim);
|
|
||||||
static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
|
||||||
const int32_t factorEnd);
|
|
||||||
static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart,
|
|
||||||
const int32_t factorEnd);
|
|
||||||
static bool CheckFactorNumSatisfy(const int32_t dim);
|
|
||||||
static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum,
|
|
||||||
bool isKDim);
|
|
||||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor);
|
|
||||||
static void GetFactors(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
|
||||||
static void GetBlockFactors(std::vector<int32_t> &factorList, const int32_t oriShape, const int32_t mpShape,
|
|
||||||
const int32_t coreNum, const int32_t maxNum);
|
|
||||||
static int32_t GetNonFactorMap(std::vector<int32_t> &factorList, int32_t srcNum, int32_t maxFactor);
|
|
||||||
static std::vector<std::pair<int, int>> GetFactorPairs(int32_t num);
|
|
||||||
static std::pair<int32_t, int32_t> DivideIntoMainAndTail(int32_t num, int32_t divisor);
|
|
||||||
};
|
|
||||||
} // namespace matmul_tiling
|
|
||||||
#endif // _MATH_UTIL_H_
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_def.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#include "register/op_def_registry.h"
|
|
||||||
|
|
||||||
namespace ops {
|
|
||||||
class MoeGatingTopK : public OpDef {
|
|
||||||
public:
|
|
||||||
explicit MoeGatingTopK(const char *name) : OpDef(name)
|
|
||||||
{
|
|
||||||
this->Input("x")
|
|
||||||
.ParamType(REQUIRED)
|
|
||||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.AutoContiguous();
|
|
||||||
this->Input("bias")
|
|
||||||
.ParamType(OPTIONAL)
|
|
||||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.AutoContiguous();
|
|
||||||
this->Output("y")
|
|
||||||
.ParamType(REQUIRED)
|
|
||||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
|
||||||
this->Output("expert_idx")
|
|
||||||
.ParamType(REQUIRED)
|
|
||||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
|
||||||
this->Output("out")
|
|
||||||
.ParamType(REQUIRED)
|
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
|
||||||
this->Attr("k").Int();
|
|
||||||
this->Attr("k_group").AttrType(OPTIONAL).Int(1);
|
|
||||||
this->Attr("group_count").AttrType(OPTIONAL).Int(1);
|
|
||||||
this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0);
|
|
||||||
this->Attr("renorm").AttrType(OPTIONAL).Int(0);
|
|
||||||
this->Attr("norm_type").AttrType(OPTIONAL).Int(0);
|
|
||||||
this->Attr("out_flag").AttrType(OPTIONAL).Bool(false);
|
|
||||||
this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0);
|
|
||||||
this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f);
|
|
||||||
this->AICore().AddConfig("ascend910b");
|
|
||||||
this->AICore().AddConfig("ascend910_93");
|
|
||||||
|
|
||||||
OpAICoreConfig regbaseCfg;
|
|
||||||
regbaseCfg.DynamicCompileStaticFlag(true)
|
|
||||||
.DynamicRankSupportFlag(true)
|
|
||||||
.DynamicShapeSupportFlag(true)
|
|
||||||
.ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt");
|
|
||||||
this->AICore().AddConfig("ascend910_95", regbaseCfg);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
OP_ADD(MoeGatingTopK);
|
|
||||||
} // namespace ops
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/* !
|
|
||||||
* \file moe_gating_top_k_infershape.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "exe_graph/runtime/infer_shape_context.h"
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#define TO_STRING(x) std::string(#x)
|
|
||||||
|
|
||||||
using namespace ge;
|
|
||||||
namespace ops {
|
|
||||||
static constexpr size_t DIM_ONE = 1;
|
|
||||||
static constexpr size_t DIM_TWO = 2;
|
|
||||||
static constexpr int64_t NEG_ONE = -1;
|
|
||||||
static constexpr int64_t X_INDEX = 0;
|
|
||||||
static constexpr int64_t BIAS_INDEX = 1;
|
|
||||||
static constexpr int64_t Y_INDEX = 0;
|
|
||||||
static constexpr int64_t EXPERT_IDX_INDEX = 1;
|
|
||||||
static constexpr int64_t OUT_INDEX = 2;
|
|
||||||
|
|
||||||
static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape)
|
|
||||||
{
|
|
||||||
int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
|
|
||||||
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
|
|
||||||
if (XRows < NEG_ONE || expertNum < NEG_ONE) {
|
|
||||||
OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str());
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape,
|
|
||||||
const int64_t k)
|
|
||||||
{
|
|
||||||
if (xShape->GetDimNum() == 1U) {
|
|
||||||
if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) {
|
|
||||||
OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.",
|
|
||||||
TO_STRING(*xShape).c_str());
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
} else if (xShape->GetDimNum() != DIM_TWO) {
|
|
||||||
OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.",
|
|
||||||
TO_STRING(*xShape).c_str());
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (k < 0) {
|
|
||||||
OP_LOGE(context, "k must be a non-negative number.");
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k)
|
|
||||||
{
|
|
||||||
OP_LOGD(context, "x shape is: %s.", TO_STRING(*xShape).c_str());
|
|
||||||
OP_LOGD(context, "k is: %ld.", k);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape,
|
|
||||||
const gert::Shape *expertIdxShape, const gert::Shape *outShape)
|
|
||||||
{
|
|
||||||
OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str());
|
|
||||||
OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str());
|
|
||||||
OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context)
|
|
||||||
{
|
|
||||||
OP_LOGD(context, "Begin to do MoeGatingTopKInfershape.");
|
|
||||||
|
|
||||||
// 获取输入shape
|
|
||||||
const gert::Shape *xShape = context->GetInputShape(0);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
|
||||||
gert::Shape *yShape = context->GetOutputShape(0);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
|
||||||
gert::Shape *expertIdxShape = context->GetOutputShape(1);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape);
|
|
||||||
gert::Shape *outShape = context->GetOutputShape(2);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, outShape);
|
|
||||||
|
|
||||||
// 获取attr
|
|
||||||
auto attrs = context->GetAttrs();
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
|
||||||
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(0);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, kPtr);
|
|
||||||
const int64_t k = *kPtr;
|
|
||||||
ShowInputShapeInfo(context, xShape, k);
|
|
||||||
|
|
||||||
// 参数校验
|
|
||||||
if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) {
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) {
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0);
|
|
||||||
int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1);
|
|
||||||
|
|
||||||
yShape->SetDimNum(DIM_TWO);
|
|
||||||
yShape->SetDim(0U, rows);
|
|
||||||
yShape->SetDim(1U, k);
|
|
||||||
|
|
||||||
expertIdxShape->SetDimNum(DIM_TWO);
|
|
||||||
expertIdxShape->SetDim(0U, rows);
|
|
||||||
expertIdxShape->SetDim(1U, k);
|
|
||||||
|
|
||||||
outShape->SetDimNum(DIM_TWO);
|
|
||||||
outShape->SetDim(0U, rows);
|
|
||||||
outShape->SetDim(1U, expertNum);
|
|
||||||
|
|
||||||
ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape);
|
|
||||||
OP_LOGD(context, "End to do MoeGatingTopKInfershape.");
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context)
|
|
||||||
{
|
|
||||||
OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType.");
|
|
||||||
auto xDtype = context->GetInputDataType(0);
|
|
||||||
context->SetOutputDataType(Y_INDEX, xDtype);
|
|
||||||
context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32);
|
|
||||||
context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT);
|
|
||||||
OP_LOGD(context, "End to do MoeGatingTopKInferDataType.");
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK);
|
|
||||||
} // namespace ops
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_proto.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#include "moe_gating_top_k_proto.h"
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_proto.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
|
||||||
#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
|
||||||
|
|
||||||
#include "graph/operator_reg.h"
|
|
||||||
|
|
||||||
namespace ge {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute renorm(sigmoid) and topk for moe input.
|
|
||||||
*
|
|
||||||
* @par Inputs:
|
|
||||||
* @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count.
|
|
||||||
* @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x.
|
|
||||||
*
|
|
||||||
* @par Outputs:
|
|
||||||
* @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x.
|
|
||||||
The size of the non-1 axis must be the same as that of the corresponding axis of x.
|
|
||||||
The size of the -1 axis must be the same as that of k.
|
|
||||||
* @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y.
|
|
||||||
* @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x.
|
|
||||||
*
|
|
||||||
* @par Attributes:
|
|
||||||
* @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value.
|
|
||||||
* @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1.
|
|
||||||
* @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1.
|
|
||||||
* @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0.
|
|
||||||
* @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0.
|
|
||||||
* @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0.
|
|
||||||
* @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false.
|
|
||||||
* @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0.
|
|
||||||
* @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20.
|
|
||||||
*/
|
|
||||||
REG_OP(MoeGatingTopK)
|
|
||||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
|
||||||
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
|
||||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
|
|
||||||
.OUTPUT(expert_idx, TensorType({DT_INT32}))
|
|
||||||
.OUTPUT(out, TensorType({DT_FLOAT}))
|
|
||||||
.REQUIRED_ATTR(k, Int)
|
|
||||||
.ATTR(k_group, Int, 1)
|
|
||||||
.ATTR(group_count, Int, 1)
|
|
||||||
.ATTR(group_select_mode, Int, 0)
|
|
||||||
.ATTR(renorm, Int, 0)
|
|
||||||
.ATTR(norm_type, Int, 0)
|
|
||||||
.ATTR(out_flag, Bool, false)
|
|
||||||
.ATTR(routed_scaling_factor, Float, 1.0)
|
|
||||||
.ATTR(eps, Float, 1e-20f)
|
|
||||||
.OP_END_FACTORY_REG(MoeGatingTopK)
|
|
||||||
|
|
||||||
} // namespace ge
|
|
||||||
|
|
||||||
#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_
|
|
||||||
@@ -1,580 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/* !
|
|
||||||
* \file moe_gating_top_k_tiling.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#include <cmath>
|
|
||||||
#include "register/op_def_registry.h"
|
|
||||||
#include "exe_graph/runtime/infer_shape_context.h"
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
#include "../tiling_base/tiling_base.h"
|
|
||||||
#include "../tiling_base/tiling_templates_registry.h"
|
|
||||||
#include "platform/platform_info.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#include "error_log.h"
|
|
||||||
#include "moe_gating_top_k_tiling.h"
|
|
||||||
|
|
||||||
// 放在文件顶部,或单独头文件中
|
|
||||||
#ifndef CEIL_ALIGN
|
|
||||||
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef CEIL_DIV
|
|
||||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
||||||
#endif
|
|
||||||
namespace optiling {
|
|
||||||
const static int64_t GROUP_SELECT_MODE_MAX = 0;
|
|
||||||
const static int64_t GROUP_SELECT_MODE_SUM = 1;
|
|
||||||
const static int64_t RENORM_NO = 0;
|
|
||||||
const static int64_t RENORM_L1 = 1;
|
|
||||||
const static int64_t NORM_TYPE_SOFTMAX = 0;
|
|
||||||
const static int64_t NORM_TYPE_SIGMOID = 1;
|
|
||||||
const static int64_t OUT_FLAG_FALSE = 0;
|
|
||||||
const static int64_t OUT_FLAG_TRUE = 1;
|
|
||||||
const static size_t X_INPUT_DIMS = 2;
|
|
||||||
const static size_t BIAS_INPUT_DIMS = 1;
|
|
||||||
const static size_t Y_OUTPUT_DIMS = 2;
|
|
||||||
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
|
|
||||||
const static size_t OUT_OUTPUT_DIMS = 2;
|
|
||||||
const static int64_t MAX_EXPERT_COUNT = 2048;
|
|
||||||
|
|
||||||
const static int64_t X_INPUT_INDEX = 0;
|
|
||||||
const static int64_t BIAS_INPUT_INDEX = 1;
|
|
||||||
const static int64_t Y_OUTPUT_INDEX = 0;
|
|
||||||
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
|
|
||||||
const static int64_t OUT_OUTPUT_INDEX = 2;
|
|
||||||
const static int64_t K_ATTR_INDEX = 0;
|
|
||||||
const static int64_t K_GROUP_ATTR_INDEX = 1;
|
|
||||||
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
|
|
||||||
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
|
|
||||||
const static int64_t RENORM_ATTR_INDEX = 4;
|
|
||||||
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
|
|
||||||
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
|
|
||||||
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
|
|
||||||
const static int64_t EPS_ATTR_INDEX = 8;
|
|
||||||
const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216; // 预留16M空间
|
|
||||||
const static uint32_t DATATYPESIZE_FLOAT = 4;
|
|
||||||
const static bool IS_LARGEST = true;
|
|
||||||
const static bool IS_INITINDEX = false;
|
|
||||||
const static bool IS_REUSESOURCE = false;
|
|
||||||
const static uint64_t WITH_GROUP_CONDITION = 1;
|
|
||||||
const static uint64_t WITHOUT_GROUP_CONDITION = 2;
|
|
||||||
const static uint64_t MAX_IN_GROUP_CONDITION = 3;
|
|
||||||
constexpr int32_t ROW_COUNT_PER_TASK = 1;
|
|
||||||
|
|
||||||
const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0;
|
|
||||||
const static uint64_t TILING_KEY_WITHOUT_GROUP = 1;
|
|
||||||
const static uint64_t TILING_KEY_GENERALIZED = 2;
|
|
||||||
|
|
||||||
inline static int64_t CeilLog4(int64_t x)
|
|
||||||
{
|
|
||||||
return static_cast<int64_t>(std::ceil(std::log(x) / std::log(4))); // 4 for four
|
|
||||||
}
|
|
||||||
|
|
||||||
class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass {
|
|
||||||
public:
|
|
||||||
explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
|
|
||||||
{
|
|
||||||
Reset();
|
|
||||||
}
|
|
||||||
~MoeGatingTopKTilingBase() override = default;
|
|
||||||
|
|
||||||
void Reset(gert::TilingContext *context) override
|
|
||||||
{
|
|
||||||
TilingBaseClass::Reset(context);
|
|
||||||
Reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool IsCapable() override
|
|
||||||
{
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
|
|
||||||
ge::graphStatus GetPlatformInfo() override;
|
|
||||||
// 2、获取INPUT/OUTPUT/ATTR信息
|
|
||||||
ge::graphStatus GetShapeAttrsInfo() override;
|
|
||||||
// 3、计算数据切分TilingData
|
|
||||||
ge::graphStatus DoOpTiling() override;
|
|
||||||
// 4、计算高阶API的TilingData
|
|
||||||
ge::graphStatus DoLibApiTiling() override;
|
|
||||||
// 5、计算TilingKey
|
|
||||||
uint64_t GetTilingKey() const override;
|
|
||||||
// 6、计算Workspace 大小
|
|
||||||
ge::graphStatus GetWorkspaceSize() override;
|
|
||||||
// 7、保存Tiling数据
|
|
||||||
ge::graphStatus PostTiling() override;
|
|
||||||
void Reset();
|
|
||||||
|
|
||||||
private:
|
|
||||||
ge::graphStatus CheckInputShape();
|
|
||||||
ge::graphStatus CheckAttr();
|
|
||||||
ge::graphStatus CheckOutShape();
|
|
||||||
void SplitRows();
|
|
||||||
void CalTmpBufUbSize();
|
|
||||||
|
|
||||||
const gert::Shape *xShape_ = nullptr;
|
|
||||||
const gert::Shape *biasShape_ = nullptr;
|
|
||||||
const gert::Shape *yShape_ = nullptr;
|
|
||||||
const gert::Shape *expertIdxShape_ = nullptr;
|
|
||||||
const gert::Shape *outShape_ = nullptr;
|
|
||||||
|
|
||||||
int64_t rows_ = 0;
|
|
||||||
int64_t expertCount_ = 0;
|
|
||||||
int64_t addBias_ = 0;
|
|
||||||
|
|
||||||
int64_t k_ = 0;
|
|
||||||
int64_t kGroup_ = 0;
|
|
||||||
int64_t groupCount_ = 0;
|
|
||||||
int64_t perGroupExpertCount_ = 0;
|
|
||||||
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
|
|
||||||
int64_t renorm_ = RENORM_NO;
|
|
||||||
int64_t normType_ = NORM_TYPE_SOFTMAX;
|
|
||||||
int64_t outFlag_ = OUT_FLAG_FALSE;
|
|
||||||
float routedScalingFactor_ = 1.0;
|
|
||||||
float eps_ = 1e-20f;
|
|
||||||
|
|
||||||
int64_t inputDtypeSize_;
|
|
||||||
const char *opName_ = "";
|
|
||||||
MoeGatingTopKTilingData moeGatingTopKTilingData_;
|
|
||||||
};
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape()
|
|
||||||
{
|
|
||||||
size_t xDimNum = xShape_->GetDimNum();
|
|
||||||
|
|
||||||
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
|
|
||||||
|
|
||||||
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
// 通过输入获取rows 和 expertCount
|
|
||||||
rows_ = xShape_->GetDim(0);
|
|
||||||
expertCount_ = xShape_->GetDim(1);
|
|
||||||
|
|
||||||
moeGatingTopKTilingData_.set_rowCount(rows_);
|
|
||||||
moeGatingTopKTilingData_.set_expertCount(expertCount_);
|
|
||||||
if (biasShape_ != nullptr) {
|
|
||||||
addBias_ = 1;
|
|
||||||
size_t biasDimNum = biasShape_->GetDimNum();
|
|
||||||
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
|
|
||||||
OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
biasShape_->GetDim(0) != expertCount_,
|
|
||||||
OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_addBias(addBias_);
|
|
||||||
|
|
||||||
OP_CHECK_IF(k_ > expertCount_,
|
|
||||||
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::CheckAttr()
|
|
||||||
{
|
|
||||||
OP_CHECK_IF(
|
|
||||||
expertCount_ > MAX_EXPERT_COUNT,
|
|
||||||
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(kGroup_ > groupCount_,
|
|
||||||
OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
|
|
||||||
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
|
|
||||||
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
|
|
||||||
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
|
|
||||||
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
|
|
||||||
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
|
|
||||||
OP_LOGE(context_, "Expert count : %ld is not divisible by k_group: %ld", expertCount_, groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
perGroupExpertCount_ = expertCount_ / groupCount_;
|
|
||||||
|
|
||||||
OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
|
|
||||||
|
|
||||||
OP_CHECK_IF(perGroupExpertCount_ < 1,
|
|
||||||
OP_LOGE(context_, "group expert count is: %ld, but should be greater than 1.", perGroupExpertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
|
|
||||||
OP_LOGE(context_,
|
|
||||||
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
|
|
||||||
perGroupExpertCount_, groupSelectMode_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
|
|
||||||
OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L);
|
|
||||||
OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign);
|
|
||||||
if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) {
|
|
||||||
// 分组场景下才需要校验对齐后的数量
|
|
||||||
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
|
|
||||||
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
|
|
||||||
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_);
|
|
||||||
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
|
|
||||||
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo()
|
|
||||||
{
|
|
||||||
opName_ = context_->GetNodeName();
|
|
||||||
// 获取输入shape信息
|
|
||||||
OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_);
|
|
||||||
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
|
|
||||||
xShape_ = &xShapePtr->GetStorageShape();
|
|
||||||
OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str());
|
|
||||||
|
|
||||||
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
|
|
||||||
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
|
|
||||||
if (biasShape_ != nullptr) {
|
|
||||||
OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str());
|
|
||||||
}
|
|
||||||
// 获取输出shape
|
|
||||||
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
|
|
||||||
yShape_ = &yShapePtr->GetStorageShape();
|
|
||||||
OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str());
|
|
||||||
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
|
|
||||||
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
|
|
||||||
OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str());
|
|
||||||
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr);
|
|
||||||
outShape_ = &outPtr->GetStorageShape();
|
|
||||||
if (outShape_ != nullptr) {
|
|
||||||
OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto x = context_->GetInputDesc(X_INPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
|
|
||||||
auto xDtype = x->GetDataType();
|
|
||||||
OP_CHECK_IF(
|
|
||||||
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
|
|
||||||
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
if (biasShapePtr != nullptr) {
|
|
||||||
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
|
|
||||||
OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str());
|
|
||||||
OP_CHECK_IF((biasDtype != xDtype),
|
|
||||||
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
|
|
||||||
auto yDtype = yDesc->GetDataType();
|
|
||||||
OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str());
|
|
||||||
OP_CHECK_IF((yDtype != xDtype),
|
|
||||||
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
|
|
||||||
auto expertIdDtype = expertIdDesc->GetDataType();
|
|
||||||
OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str());
|
|
||||||
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
|
|
||||||
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc);
|
|
||||||
auto normOutDtype = normOutDesc->GetDataType();
|
|
||||||
OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT),
|
|
||||||
OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
// 获取属性
|
|
||||||
auto attrs = context_->GetAttrs();
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
|
|
||||||
|
|
||||||
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
|
|
||||||
k_ = *kPtr;
|
|
||||||
OP_LOGI(context_, "Attr k is: %ld", k_);
|
|
||||||
moeGatingTopKTilingData_.set_k(k_);
|
|
||||||
|
|
||||||
OP_LOGI(context_, "Attr k is: %ld ", k_);
|
|
||||||
|
|
||||||
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
|
|
||||||
if (kGroupPtr != nullptr) {
|
|
||||||
kGroup_ = *kGroupPtr;
|
|
||||||
OP_LOGI(context_, "Attr k_group is: %ld", kGroup_);
|
|
||||||
moeGatingTopKTilingData_.set_kGroup(kGroup_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
|
|
||||||
|
|
||||||
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
|
|
||||||
if (groupCountPtr != nullptr) {
|
|
||||||
groupCount_ = *groupCountPtr;
|
|
||||||
OP_LOGI(context_, "Attr group_count is: %ld", groupCount_);
|
|
||||||
moeGatingTopKTilingData_.set_groupCount(groupCount_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
|
|
||||||
|
|
||||||
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
|
|
||||||
if (groupSelectModePtr != nullptr) {
|
|
||||||
groupSelectMode_ = *groupSelectModePtr;
|
|
||||||
OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_);
|
|
||||||
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
|
|
||||||
|
|
||||||
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
|
|
||||||
if (renormPtr != nullptr) {
|
|
||||||
renorm_ = *renormPtr;
|
|
||||||
OP_LOGI(context_, "Attr renorm is: %ld", renorm_);
|
|
||||||
moeGatingTopKTilingData_.set_renorm(renorm_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
|
|
||||||
|
|
||||||
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
|
|
||||||
if (normTypePtr != nullptr) {
|
|
||||||
normType_ = *normTypePtr;
|
|
||||||
OP_LOGI(context_, "Attr norm_type is: %ld", normType_);
|
|
||||||
moeGatingTopKTilingData_.set_normType(normType_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
|
|
||||||
|
|
||||||
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
|
|
||||||
if (outFlagPtr != nullptr) {
|
|
||||||
outFlag_ = (*outFlagPtr) ? 1 : 0;
|
|
||||||
OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_);
|
|
||||||
moeGatingTopKTilingData_.set_outFlag(outFlag_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
|
|
||||||
|
|
||||||
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
|
|
||||||
if (routedScalingFactorPtr != nullptr) {
|
|
||||||
routedScalingFactor_ = *routedScalingFactorPtr;
|
|
||||||
OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_);
|
|
||||||
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
|
|
||||||
|
|
||||||
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
|
|
||||||
if (epsPtr != nullptr) {
|
|
||||||
eps_ = *epsPtr;
|
|
||||||
OP_LOGI(context_, "Attr eps is: %f", eps_);
|
|
||||||
moeGatingTopKTilingData_.set_eps(eps_);
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr eps is: %f ", eps_);
|
|
||||||
|
|
||||||
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
|
|
||||||
OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_);
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo()
|
|
||||||
|
|
||||||
{
|
|
||||||
auto platformInfo = context_->GetPlatformInfo();
|
|
||||||
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
|
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
|
|
||||||
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
|
|
||||||
uint64_t ubSizePlatForm;
|
|
||||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
|
|
||||||
aicoreParams_.ubSize = ubSizePlatForm;
|
|
||||||
OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize);
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape()
|
|
||||||
{
|
|
||||||
OP_LOGI(context_, "555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());
|
|
||||||
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
|
|
||||||
xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
|
|
||||||
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
|
|
||||||
outShape_->GetDimNum(), xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
|
|
||||||
xShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
|
|
||||||
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outFlag_ && outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
|
|
||||||
outShape_->GetDim(0), outShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
OP_CHECK_IF((yShape_->GetDim(1) != k_),
|
|
||||||
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
|
|
||||||
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outFlag_ && outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
|
|
||||||
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
|
|
||||||
xShape_->GetDim(1)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingBase::SplitRows()
|
|
||||||
{
|
|
||||||
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
|
|
||||||
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
|
|
||||||
// perCoreRows cannot be 0
|
|
||||||
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
|
|
||||||
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
|
|
||||||
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
|
|
||||||
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
|
|
||||||
int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L));
|
|
||||||
OP_LOGI(context_, "vms count is: %ld", vmsCount);
|
|
||||||
moeGatingTopKTilingData_.set_vmsCount(vmsCount); // 需要归并的轮数
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingBase::CalTmpBufUbSize()
|
|
||||||
|
|
||||||
{
|
|
||||||
|
|
||||||
std::vector<int64_t> shape_vec = {expertCount_};
|
|
||||||
ge::Shape shape(shape_vec);
|
|
||||||
uint32_t maxValue = 0;
|
|
||||||
uint32_t minValue = 0;
|
|
||||||
AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue);
|
|
||||||
|
|
||||||
int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast<int64_t>(sizeof(float));
|
|
||||||
moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast<int64_t>(minValue)));
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling()
|
|
||||||
{
|
|
||||||
|
|
||||||
OP_LOGI(context_, "DoOpTiling: start");
|
|
||||||
auto ret = CheckInputShape();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = CheckOutShape();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = CheckAttr();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
CalTmpBufUbSize();
|
|
||||||
SplitRows();
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling()
|
|
||||||
{
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize()
|
|
||||||
{
|
|
||||||
// 计算workspace大小
|
|
||||||
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingBase::PostTiling()
|
|
||||||
{
|
|
||||||
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
|
|
||||||
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
|
|
||||||
currentWorkspace[0] = workspaceSize_;
|
|
||||||
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
|
|
||||||
context_->GetRawTilingData()->GetCapacity());
|
|
||||||
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t MoeGatingTopKTilingBase::GetTilingKey() const
|
|
||||||
{
|
|
||||||
// DeepSeekV3排序对齐高性能场景
|
|
||||||
if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ &&
|
|
||||||
groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID &&
|
|
||||||
!outFlag_) {
|
|
||||||
// DeepSeekV3排序对齐高性能场景
|
|
||||||
return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF;
|
|
||||||
} else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) {
|
|
||||||
/**
|
|
||||||
* 不分组场景:
|
|
||||||
* 1. 分组数为 1
|
|
||||||
* 2. 分组数等于专家数(每个组只有一个专家)
|
|
||||||
* 3. 选择所有组
|
|
||||||
*/
|
|
||||||
return TILING_KEY_WITHOUT_GROUP;
|
|
||||||
} else {
|
|
||||||
return TILING_KEY_GENERALIZED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingBase::Reset()
|
|
||||||
{
|
|
||||||
opName_ = nullptr;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000);
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_tiling.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
|
||||||
#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
|
|
||||||
#include "../tiling_base/tiling_base.h"
|
|
||||||
#include "../tiling_base/tiling_templates_registry.h"
|
|
||||||
#include "register/op_def_registry.h"
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
#include "register/tilingdata_base.h"
|
|
||||||
#include "tiling/tiling_api.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
#include "platform/platform_infos_def.h"
|
|
||||||
#include "math_util.h"
|
|
||||||
//#include "util/extern_math_util.h"
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData)
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, rowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, expertCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, addBias);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, k);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, kGroup);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, groupCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, renorm);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, normType);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, outFlag);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
|
|
||||||
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
|
|
||||||
TILING_DATA_FIELD_DEF(float, eps);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize);
|
|
||||||
END_TILING_DATA_DEF;
|
|
||||||
REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData)
|
|
||||||
|
|
||||||
BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData)
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, needCoreNum);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, rowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, expertCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, addBias);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, k);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, kGroup);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, groupCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, groupSelectMode);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, renorm);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, normType);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, outFlag);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, vmsCount);
|
|
||||||
TILING_DATA_FIELD_DEF(float, routedScalingFactor);
|
|
||||||
TILING_DATA_FIELD_DEF(float, eps);
|
|
||||||
TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData);
|
|
||||||
END_TILING_DATA_DEF;
|
|
||||||
REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData)
|
|
||||||
struct MoeGatingTopKCompileInfo {};
|
|
||||||
} // namespace optiling
|
|
||||||
#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H
|
|
||||||
@@ -1,521 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/* !
|
|
||||||
* \file moe_gating_top_k_tiling_arch35.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "error_log.h"
|
|
||||||
#include "moe_gating_top_k_tiling.h"
|
|
||||||
#include "register/op_def_registry.h"
|
|
||||||
#include "platform/platform_info.h"
|
|
||||||
#include "../tiling_base/tiling_base.h"
|
|
||||||
#include "../tiling_base/tiling_templates_registry.h"
|
|
||||||
|
|
||||||
#ifndef CEIL_ALIGN
|
|
||||||
#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align))
|
|
||||||
#endif
|
|
||||||
#ifndef CEIL_DIV
|
|
||||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
||||||
#endif
|
|
||||||
namespace optiling {
|
|
||||||
const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000;
|
|
||||||
|
|
||||||
const static int64_t GROUP_SELECT_MODE_MAX = 0;
|
|
||||||
const static int64_t GROUP_SELECT_MODE_SUM = 1;
|
|
||||||
const static int64_t RENORM_NO = 0;
|
|
||||||
const static int64_t RENORM_L1 = 1;
|
|
||||||
const static int64_t NORM_TYPE_SOFTMAX = 0;
|
|
||||||
const static int64_t NORM_TYPE_SIGMOID = 1;
|
|
||||||
const static int64_t OUT_FLAG_FALSE = 0;
|
|
||||||
const static int64_t OUT_FLAG_TRUE = 1;
|
|
||||||
const static size_t X_INPUT_DIMS = 2;
|
|
||||||
const static size_t BIAS_INPUT_DIMS = 1;
|
|
||||||
const static size_t Y_OUTPUT_DIMS = 2;
|
|
||||||
const static size_t EXPERT_IDX_OUTPUY_DIMS = 2;
|
|
||||||
const static size_t OUT_OUTPUT_DIMS = 2;
|
|
||||||
const static int64_t MAX_EXPERT_COUNT = 2048;
|
|
||||||
|
|
||||||
const static int64_t X_INPUT_INDEX = 0;
|
|
||||||
const static int64_t BIAS_INPUT_INDEX = 1;
|
|
||||||
const static int64_t Y_OUTPUT_INDEX = 0;
|
|
||||||
const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1;
|
|
||||||
const static int64_t OUT_OUTPUT_INDEX = 2;
|
|
||||||
const static int64_t K_ATTR_INDEX = 0;
|
|
||||||
const static int64_t K_GROUP_ATTR_INDEX = 1;
|
|
||||||
const static int64_t GROUP_COUNT_ATTR_INDEX = 2;
|
|
||||||
const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3;
|
|
||||||
const static int64_t RENORM_ATTR_INDEX = 4;
|
|
||||||
const static int64_t MRGSORT_SIZE = 4;
|
|
||||||
const static int64_t NORM_TYPE_ATTR_INDEX = 5;
|
|
||||||
const static int64_t OUT_FLAG_ATTR_INDEX = 6;
|
|
||||||
const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7;
|
|
||||||
const static int64_t EPS_ATTR_INDEX = 8;
|
|
||||||
const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast<int64_t>(16 * 1024 * 1024); // 预留16M空间
|
|
||||||
|
|
||||||
|
|
||||||
class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass {
|
|
||||||
public:
|
|
||||||
explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context)
|
|
||||||
{
|
|
||||||
Reset();
|
|
||||||
}
|
|
||||||
~MoeGatingTopKTilingRegbase() override = default;
|
|
||||||
|
|
||||||
void Reset(gert::TilingContext *context) override
|
|
||||||
{
|
|
||||||
TilingBaseClass::Reset(context);
|
|
||||||
Reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool IsCapable() override
|
|
||||||
{
|
|
||||||
if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
|
|
||||||
ge::graphStatus GetPlatformInfo() override;
|
|
||||||
// 2、获取INPUT/OUTPUT/ATTR信息
|
|
||||||
ge::graphStatus GetShapeAttrsInfo() override;
|
|
||||||
// 3、计算数据切分TilingData
|
|
||||||
ge::graphStatus DoOpTiling() override;
|
|
||||||
// 4、计算高阶API的TilingData
|
|
||||||
ge::graphStatus DoLibApiTiling() override;
|
|
||||||
// 5、计算TilingKey
|
|
||||||
uint64_t GetTilingKey() const override;
|
|
||||||
// 6、计算Workspace 大小
|
|
||||||
ge::graphStatus GetWorkspaceSize() override;
|
|
||||||
// 7、保存Tiling数据
|
|
||||||
ge::graphStatus PostTiling() override;
|
|
||||||
void Reset();
|
|
||||||
|
|
||||||
private:
|
|
||||||
ge::graphStatus CheckInputShape();
|
|
||||||
ge::graphStatus CheckAttr();
|
|
||||||
ge::graphStatus CheckOutShape();
|
|
||||||
void CalTmpBufUbSize();
|
|
||||||
void SplitRows();
|
|
||||||
void Tiling4GatherOutComputeSplitK();
|
|
||||||
|
|
||||||
const gert::Shape *xShape_ = nullptr;
|
|
||||||
const gert::Shape *biasShape_ = nullptr;
|
|
||||||
const gert::Shape *yShape_ = nullptr;
|
|
||||||
const gert::Shape *expertIdxShape_ = nullptr;
|
|
||||||
const gert::Shape *outShape_ = nullptr;
|
|
||||||
|
|
||||||
int64_t rows_;
|
|
||||||
int64_t expertCount_;
|
|
||||||
int64_t addBias_ = 0;
|
|
||||||
|
|
||||||
int64_t k_;
|
|
||||||
int64_t kGroup_ = 1;
|
|
||||||
int64_t groupCount_ = 1;
|
|
||||||
int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX;
|
|
||||||
int64_t renorm_ = RENORM_NO;
|
|
||||||
int64_t normType_ = NORM_TYPE_SOFTMAX;
|
|
||||||
int64_t outFlag_ = OUT_FLAG_FALSE;
|
|
||||||
float routedScalingFactor_ = 1.0;
|
|
||||||
float eps_ = 1e-20f;
|
|
||||||
|
|
||||||
int64_t inputDtypeSize_;
|
|
||||||
const char *opName_ = "";
|
|
||||||
MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_;
|
|
||||||
platform_ascendc::SocVersion socVersion;
|
|
||||||
};
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape()
|
|
||||||
{
|
|
||||||
size_t xDimNum = xShape_->GetDimNum();
|
|
||||||
OP_CHECK_IF(xDimNum != X_INPUT_DIMS,
|
|
||||||
OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
// 通过输入获取rows 和 expertCount
|
|
||||||
rows_ = xShape_->GetDim(0);
|
|
||||||
expertCount_ = xShape_->GetDim(1);
|
|
||||||
moeGatingTopKTilingData_.set_rowCount(rows_);
|
|
||||||
moeGatingTopKTilingData_.set_expertCount(expertCount_);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
expertCount_ > MAX_EXPERT_COUNT,
|
|
||||||
OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
if (biasShape_ != nullptr) {
|
|
||||||
addBias_ = 1;
|
|
||||||
size_t biasDimNum = biasShape_->GetDimNum();
|
|
||||||
OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS,
|
|
||||||
OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_,
|
|
||||||
OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.",
|
|
||||||
biasShape_->GetDim(0), expertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_addBias(addBias_);
|
|
||||||
|
|
||||||
OP_CHECK_IF(k_ > expertCount_,
|
|
||||||
OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr()
|
|
||||||
{
|
|
||||||
OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(expertCount_ % groupCount_ != 0,
|
|
||||||
OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(kGroup_ > groupCount_,
|
|
||||||
OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_,
|
|
||||||
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
|
|
||||||
kGroup_, k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
if (kGroup_ == groupCount_ || groupCount_ == expertCount_) {
|
|
||||||
kGroup_ = 1;
|
|
||||||
groupCount_ = 1;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_kGroup(kGroup_);
|
|
||||||
moeGatingTopKTilingData_.set_groupCount(groupCount_);
|
|
||||||
int64_t groupExpertCount = expertCount_ / groupCount_;
|
|
||||||
int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L);
|
|
||||||
moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_);
|
|
||||||
moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign);
|
|
||||||
|
|
||||||
OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT,
|
|
||||||
OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.",
|
|
||||||
groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(kGroup_ * groupExpertCount < k_,
|
|
||||||
OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.",
|
|
||||||
kGroup_ * groupExpertCount, k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(groupExpertCount < 1,
|
|
||||||
OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
|
|
||||||
OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
|
|
||||||
GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2,
|
|
||||||
OP_LOGE(context_,
|
|
||||||
"group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
|
|
||||||
groupExpertCount, groupSelectMode_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(renorm_ != RENORM_NO,
|
|
||||||
OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
|
|
||||||
OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
|
|
||||||
NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo()
|
|
||||||
{
|
|
||||||
opName_ = context_->GetNodeName();
|
|
||||||
// 获取输入shape信息
|
|
||||||
auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr);
|
|
||||||
xShape_ = &xShapePtr->GetStorageShape();
|
|
||||||
auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX);
|
|
||||||
biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape();
|
|
||||||
|
|
||||||
// 获取输出shape
|
|
||||||
auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr);
|
|
||||||
yShape_ = &yShapePtr->GetStorageShape();
|
|
||||||
auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr);
|
|
||||||
expertIdxShape_ = &expertIdxPtr->GetStorageShape();
|
|
||||||
auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX);
|
|
||||||
if (outPtr != nullptr) {
|
|
||||||
outShape_ = &outPtr->GetStorageShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto x = context_->GetInputDesc(X_INPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, x);
|
|
||||||
auto xDtype = x->GetDataType();
|
|
||||||
OP_CHECK_IF(
|
|
||||||
(xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16),
|
|
||||||
OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
if (biasShapePtr != nullptr) {
|
|
||||||
auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType();
|
|
||||||
OP_CHECK_IF((biasDtype != xDtype),
|
|
||||||
OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(),
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc);
|
|
||||||
auto yDtype = yDesc->GetDataType();
|
|
||||||
OP_CHECK_IF((yDtype != xDtype),
|
|
||||||
OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(),
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc);
|
|
||||||
auto expertIdDtype = expertIdDesc->GetDataType();
|
|
||||||
OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32),
|
|
||||||
OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
|
|
||||||
// 获取属性
|
|
||||||
auto attrs = context_->GetAttrs();
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, attrs);
|
|
||||||
|
|
||||||
const int64_t *kPtr = attrs->GetAttrPointer<int64_t>(K_ATTR_INDEX);
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr);
|
|
||||||
k_ = *kPtr;
|
|
||||||
moeGatingTopKTilingData_.set_k(k_);
|
|
||||||
OP_LOGI(context_, "Attr k is: %ld ", k_);
|
|
||||||
|
|
||||||
const int64_t *kGroupPtr = attrs->GetAttrPointer<int64_t>(K_GROUP_ATTR_INDEX);
|
|
||||||
if (kGroupPtr != nullptr) {
|
|
||||||
kGroup_ = *kGroupPtr;
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_);
|
|
||||||
|
|
||||||
const int64_t *groupCountPtr = attrs->GetAttrPointer<int64_t>(GROUP_COUNT_ATTR_INDEX);
|
|
||||||
if (groupCountPtr != nullptr) {
|
|
||||||
groupCount_ = *groupCountPtr;
|
|
||||||
}
|
|
||||||
OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_);
|
|
||||||
|
|
||||||
const int64_t *groupSelectModePtr = attrs->GetAttrPointer<int64_t>(GROUP_SELECT_MODE_ATTR_INDEX);
|
|
||||||
if (groupSelectModePtr != nullptr) {
|
|
||||||
groupSelectMode_ = *groupSelectModePtr;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_);
|
|
||||||
OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_);
|
|
||||||
|
|
||||||
const int64_t *renormPtr = attrs->GetAttrPointer<int64_t>(RENORM_ATTR_INDEX);
|
|
||||||
if (renormPtr != nullptr) {
|
|
||||||
renorm_ = *renormPtr;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_renorm(renorm_);
|
|
||||||
OP_LOGI(context_, "Attr renorm is: %ld ", renorm_);
|
|
||||||
|
|
||||||
const int64_t *normTypePtr = attrs->GetAttrPointer<int64_t>(NORM_TYPE_ATTR_INDEX);
|
|
||||||
if (normTypePtr != nullptr) {
|
|
||||||
normType_ = *normTypePtr;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_normType(normType_);
|
|
||||||
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_);
|
|
||||||
|
|
||||||
const bool *outFlagPtr = attrs->GetAttrPointer<bool>(OUT_FLAG_ATTR_INDEX);
|
|
||||||
if (outFlagPtr != nullptr) {
|
|
||||||
outFlag_ = (*outFlagPtr) ? 1 : 0;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_outFlag(outFlag_);
|
|
||||||
OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_);
|
|
||||||
|
|
||||||
const float *routedScalingFactorPtr = attrs->GetAttrPointer<float>(ROUTED_SCALING_FACTOR_ATTR_INDEX);
|
|
||||||
if (routedScalingFactorPtr != nullptr) {
|
|
||||||
routedScalingFactor_ = *routedScalingFactorPtr;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_);
|
|
||||||
OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_);
|
|
||||||
|
|
||||||
const float *epsPtr = attrs->GetAttrPointer<float>(EPS_ATTR_INDEX);
|
|
||||||
if (epsPtr != nullptr) {
|
|
||||||
eps_ = *epsPtr;
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_eps(eps_);
|
|
||||||
OP_LOGI(context_, "Attr eps is: %f ", eps_);
|
|
||||||
|
|
||||||
auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX);
|
|
||||||
if (outFlag_ && outDesc != nullptr) {
|
|
||||||
auto outDtype = outDesc->GetDataType();
|
|
||||||
OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT),
|
|
||||||
OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.",
|
|
||||||
ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
inputDtypeSize_ = static_cast<int64_t>(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType()));
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo()
|
|
||||||
{
|
|
||||||
auto platformInfo = context_->GetPlatformInfo();
|
|
||||||
OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED);
|
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
|
|
||||||
aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
|
|
||||||
socVersion = ascendcPlatform.GetSocVersion();
|
|
||||||
uint64_t ubSizePlatForm;
|
|
||||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
|
|
||||||
aicoreParams_.ubSize = ubSizePlatForm;
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape()
|
|
||||||
{
|
|
||||||
OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
|
|
||||||
xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
|
|
||||||
expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
|
|
||||||
OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
|
|
||||||
outShape_->GetDimNum(), xShape_->GetDimNum()),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
|
|
||||||
xShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
|
|
||||||
expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outFlag_ && outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
|
|
||||||
OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
|
|
||||||
outShape_->GetDim(0), outShape_->GetDim(0)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
|
|
||||||
OP_CHECK_IF((yShape_->GetDim(1) != k_),
|
|
||||||
OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
|
|
||||||
OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
if (outFlag_ && outShape_ != nullptr) {
|
|
||||||
OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
|
|
||||||
OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1),
|
|
||||||
xShape_->GetDim(1)),
|
|
||||||
return ge::GRAPH_FAILED);
|
|
||||||
}
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() {
|
|
||||||
std::vector<int64_t> shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()};
|
|
||||||
ge::Shape softmaxShape(shape_vec);
|
|
||||||
|
|
||||||
uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true);
|
|
||||||
AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData);
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingRegbase::SplitRows()
|
|
||||||
{
|
|
||||||
int64_t perCoreRows = CEIL_DIV(rows_, static_cast<int64_t>(aicoreParams_.blockDim));
|
|
||||||
int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows);
|
|
||||||
if (perCoreRows == 0) {
|
|
||||||
OP_LOGE(context_, "perCoreRows can't be 0.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows;
|
|
||||||
moeGatingTopKTilingData_.set_needCoreNum(needCoreNum);
|
|
||||||
moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows);
|
|
||||||
moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows);
|
|
||||||
|
|
||||||
int64_t vmsCount = 0;
|
|
||||||
if (kGroup_ > MRGSORT_SIZE) {
|
|
||||||
int64_t index = MRGSORT_SIZE;
|
|
||||||
while (index < kGroup_) {
|
|
||||||
index = index * MRGSORT_SIZE;
|
|
||||||
vmsCount++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
moeGatingTopKTilingData_.set_vmsCount(vmsCount);
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling()
|
|
||||||
{
|
|
||||||
auto ret = CheckInputShape();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = CheckAttr();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = CheckOutShape();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
CalTmpBufUbSize();
|
|
||||||
SplitRows();
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling()
|
|
||||||
{
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize()
|
|
||||||
{
|
|
||||||
// 计算workspace大小
|
|
||||||
workspaceSize_ = DEFAULT_WORKSPACE_SIZE;
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling()
|
|
||||||
{
|
|
||||||
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum());
|
|
||||||
size_t *currentWorkspace = context_->GetWorkspaceSizes(1);
|
|
||||||
currentWorkspace[0] = workspaceSize_;
|
|
||||||
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(),
|
|
||||||
context_->GetRawTilingData()->GetCapacity());
|
|
||||||
context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize());
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const
|
|
||||||
{
|
|
||||||
return MOE_GATING_TOP_K_REGBASE_TILING_KEY;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MoeGatingTopKTilingRegbase::Reset()
|
|
||||||
{
|
|
||||||
opName_ = nullptr;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000);
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/* !
|
|
||||||
* \file moe_gating_top_k_tiling_base.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#include "moe_gating_top_k_tiling.h"
|
|
||||||
#include "register/op_def_registry.h"
|
|
||||||
#include "../tiling_base/tiling_base.h"
|
|
||||||
#include "../tiling_base/tiling_templates_registry.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
#include "kernel_tiling/kernel_tiling.h"
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context)
|
|
||||||
{
|
|
||||||
return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context)
|
|
||||||
{
|
|
||||||
(void)context;
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
IMPL_OP_OPTILING(MoeGatingTopK)
|
|
||||||
.Tiling(TilingForMoeGatingTopK)
|
|
||||||
.TilingParse<MoeGatingTopKCompileInfo>(TilingPrepareForMoeGatingTopK);
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file common.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#ifndef MOE_GATING_TOP_K_COMMON_H
|
|
||||||
#define MOE_GATING_TOP_K_COMMON_H
|
|
||||||
|
|
||||||
#include "kernel_operator.h"
|
|
||||||
|
|
||||||
namespace MoeGatingTopK {
|
|
||||||
using namespace AscendC;
|
|
||||||
const float MIN_FP32 = *(float *)(&F32_NEG_INF);
|
|
||||||
constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040
|
|
||||||
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;
|
|
||||||
constexpr int64_t BLOCK_BYTES = 32;
|
|
||||||
constexpr int64_t REPEAT_BYTES = 256;
|
|
||||||
constexpr int64_t REPEAT_BLOCKS = 8;
|
|
||||||
|
|
||||||
constexpr int32_t CONSTANT_TWO = 2;
|
|
||||||
constexpr int32_t CONSTANT_THREE = 3;
|
|
||||||
constexpr int32_t CONSTANT_FOUR = 4;
|
|
||||||
constexpr int32_t CONSTANT_EIGHT = 8;
|
|
||||||
|
|
||||||
constexpr int64_t MERGE_LIST_TWO = 2;
|
|
||||||
constexpr int64_t MERGE_LIST_THREE = 3;
|
|
||||||
constexpr int64_t MERGE_LIST_FOUR = 4;
|
|
||||||
|
|
||||||
constexpr int64_t MERGE_LIST_IDX_TWO = 2;
|
|
||||||
constexpr int64_t MERGE_LIST_IDX_THREE = 3;
|
|
||||||
|
|
||||||
constexpr int64_t NORM_TYPE_SOFTMAX = 0;
|
|
||||||
constexpr int64_t NORM_TYPE_SIGMOID = 1;
|
|
||||||
|
|
||||||
__aicore__ inline int64_t Ceil(int64_t a, int64_t b)
|
|
||||||
{
|
|
||||||
if (b == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return (a + b - 1) / b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes)
|
|
||||||
{
|
|
||||||
if (bytes == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes)
|
|
||||||
{
|
|
||||||
return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline T Min(T a, T b)
|
|
||||||
{
|
|
||||||
return a > b ? b : a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline T Max(T a, T b)
|
|
||||||
{
|
|
||||||
return a < b ? b : a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T1, typename T2>
|
|
||||||
__aicore__ inline T1 CeilDiv(T1 x, T2 y)
|
|
||||||
{
|
|
||||||
if (y != 0 && x != 0) {
|
|
||||||
const T1 quotient = x / y;
|
|
||||||
return (x % y != 0 && ((x ^ y) >= 0)) ? (quotient + 1) : quotient;
|
|
||||||
}
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace MoeGatingTopK
|
|
||||||
#endif // MOE_GATING_TOP_K_COMMON_H
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include "toolchain/slog.h"
|
|
||||||
|
|
||||||
#define OP_LOGI(opname, ...)
|
|
||||||
#define OP_LOGW(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGD(opname, ...)
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
|
||||||
do { \
|
|
||||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
|
|
||||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
|
||||||
do { \
|
|
||||||
if (cond) { \
|
|
||||||
log_func; \
|
|
||||||
expr; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
|
||||||
do { \
|
|
||||||
if ((ptr) == nullptr) { \
|
|
||||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
|
||||||
return ge::GRAPH_FAILED; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
|
|
||||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "moe_gating_top_k_e_k_fullload.h"
|
|
||||||
#include "moe_gating_top_k_without_group.h"
|
|
||||||
#include "moe_gating_top_k_generalized.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
|
|
||||||
#define TILING_KEY_PER_GROUP_COUNT_32 0
|
|
||||||
#define TILING_KEY_WITHOUT_GROUP 1
|
|
||||||
#define TILING_KEY_GENERALIZED 2
|
|
||||||
|
|
||||||
using namespace AscendC;
|
|
||||||
using namespace MoeGatingTopK;
|
|
||||||
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
|
|
||||||
GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
|
|
||||||
{
|
|
||||||
|
|
||||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
|
|
||||||
if (g_coreType == AIC) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling);
|
|
||||||
if (workspace == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
GM_ADDR userWS = GetUserWorkspace(workspace);
|
|
||||||
if (userWS == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const MoeGatingTopKTilingData *__restrict t = &tilingData;
|
|
||||||
TPipe tPipe;
|
|
||||||
if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) {
|
|
||||||
MoeGatingTopKEKFullload<DTYPE_X> op;
|
|
||||||
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
|
|
||||||
op.Process();
|
|
||||||
} else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) {
|
|
||||||
MoeGatingTopKWithoutGroup<DTYPE_X> op;
|
|
||||||
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
|
|
||||||
op.Process();
|
|
||||||
} else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) {
|
|
||||||
MoeGatingTopKGenerlized<DTYPE_X> op;
|
|
||||||
op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe);
|
|
||||||
op.Process();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_apt.cpp
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "arch35/moe_gating_top_k_regbase.h"
|
|
||||||
using namespace AscendC;
|
|
||||||
using namespace MoeGatingTopK;
|
|
||||||
|
|
||||||
#define TILING_KEY_REGBASE 10000
|
|
||||||
|
|
||||||
extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
|
|
||||||
GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
|
|
||||||
{
|
|
||||||
if (g_coreType == AIC) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (workspace == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
GM_ADDR userWS = GetUserWorkspace(workspace);
|
|
||||||
if (userWS == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling);
|
|
||||||
const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in;
|
|
||||||
TPipe tPipe;
|
|
||||||
if (TILING_KEY_IS(TILING_KEY_REGBASE)) {
|
|
||||||
MoeGatingTopKRegbase<DTYPE_X> op;
|
|
||||||
op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe);
|
|
||||||
op.Process();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,404 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_e_k_fullload.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
|
||||||
#define MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
|
||||||
#include "kernel_operator.h"
|
|
||||||
#include "common.h"
|
|
||||||
namespace MoeGatingTopK {
|
|
||||||
using namespace AscendC;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class MoeGatingTopKEKFullload {
|
|
||||||
public:
|
|
||||||
__aicore__ inline MoeGatingTopKEKFullload(){};
|
|
||||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
|
|
||||||
__aicore__ inline void Process();
|
|
||||||
|
|
||||||
private:
|
|
||||||
__aicore__ inline void CopyInBias();
|
|
||||||
__aicore__ inline void CopyInX(int64_t progress);
|
|
||||||
__aicore__ inline void ComputeX();
|
|
||||||
__aicore__ inline void SortInGroup();
|
|
||||||
__aicore__ inline void SelectTopKGroupIndex();
|
|
||||||
__aicore__ inline void SelectTopKExpertIdx();
|
|
||||||
__aicore__ inline void SelectTopKExpertScore();
|
|
||||||
__aicore__ inline void CopyOut(int64_t progress);
|
|
||||||
|
|
||||||
private:
|
|
||||||
TPipe *pipe_;
|
|
||||||
TQue<QuePosition::VECIN, 1> xInQueue_;
|
|
||||||
TBuf<TPosition::VECCALC> biasInQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> yOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> outOutQueue_;
|
|
||||||
|
|
||||||
TQue<QuePosition::VECOUT, 1> xBiasQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> xSigmoidQueue_;
|
|
||||||
TQue<QuePosition::VECIN, 1> sigmoidTmpQueue_;
|
|
||||||
TQue<QuePosition::VECIN, 1> sortedInGroupQueue_;
|
|
||||||
TQue<QuePosition::VECIN, 1> sortedGroupQueue_;
|
|
||||||
TBuf<TPosition::VECCALC> calcTmpBuffer_;
|
|
||||||
|
|
||||||
GlobalTensor<T> xGm_;
|
|
||||||
GlobalTensor<T> biasGm_;
|
|
||||||
GlobalTensor<T> yGm_;
|
|
||||||
GlobalTensor<int32_t> expertIdxGm_;
|
|
||||||
GlobalTensor<T> outGm_;
|
|
||||||
|
|
||||||
int64_t blockIdx_;
|
|
||||||
int64_t perCoreRowCount_;
|
|
||||||
int64_t curCoreRowCount_;
|
|
||||||
int64_t expertCount_;
|
|
||||||
bool addBias_;
|
|
||||||
int64_t k_;
|
|
||||||
int64_t kGroup_;
|
|
||||||
int64_t groupCount_;
|
|
||||||
int64_t groupSelectMode_;
|
|
||||||
int64_t renorm_;
|
|
||||||
int64_t normType_;
|
|
||||||
int64_t outFlag_;
|
|
||||||
float routedScalingFactor_;
|
|
||||||
float eps_;
|
|
||||||
|
|
||||||
int64_t expertCountAlign_;
|
|
||||||
int64_t kAlign_;
|
|
||||||
int64_t perGroupExpertCount_;
|
|
||||||
|
|
||||||
const MoeGatingTopKTilingData *tilingData_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInBias()
|
|
||||||
{
|
|
||||||
LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE, expertCount_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyInX(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
|
|
||||||
dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
|
|
||||||
expertCount_);
|
|
||||||
}
|
|
||||||
|
|
||||||
xInQueue_.EnQue(xInLocalTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::ComputeX()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
|
|
||||||
LocalTensor<float> xBiasTensor = xBiasQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<float> biasTensor = biasInQueue_.Get<float>();
|
|
||||||
LocalTensor<uint8_t> sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor<uint8_t>(); // 临时空间可以复用
|
|
||||||
Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
if (addBias_) {
|
|
||||||
Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_);
|
|
||||||
} else {
|
|
||||||
Adds(xBiasTensor, xSigmoidTensor, static_cast<float>(0), expertCount_);
|
|
||||||
}
|
|
||||||
|
|
||||||
xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
|
|
||||||
xBiasQueue_.EnQue<float>(xBiasTensor);
|
|
||||||
xInQueue_.FreeTensor(xInLocalTensor);
|
|
||||||
sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::SortInGroup()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xBiasTensor = xBiasQueue_.DeQue<float>();
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.AllocTensor<float>(); // 组内排序的结果, 后续归并需要
|
|
||||||
LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>(); // 用于存储排序时的索引
|
|
||||||
ArithProgression(indexTensor.ReinterpretCast<int32_t>(), 0, 1, expertCount_); // 生成组索引0 1 2 ......
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // 组内排序
|
|
||||||
sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
|
|
||||||
xBiasQueue_.FreeTensor(xBiasTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKGroupIndex()
|
|
||||||
{
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
|
|
||||||
LocalTensor<uint32_t> indexTensor = calcTmpBuffer_.Get<uint32_t>();
|
|
||||||
LocalTensor<float> top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor<float>(); // 这个临时空间可以复用
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
|
|
||||||
indexTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
|
|
||||||
indexTensor.SetValue(1, static_cast<uint32_t>(0));
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
GatherMaskParams gatherMaskParams;
|
|
||||||
gatherMaskParams.repeatTimes = 8;
|
|
||||||
gatherMaskParams.src0BlockStride = 1;
|
|
||||||
gatherMaskParams.src0RepeatStride = 8;
|
|
||||||
gatherMaskParams.src1RepeatStride = 0;
|
|
||||||
GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast<uint32_t>(64),
|
|
||||||
gatherMaskParams, rsvdCnt);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
LocalTensor<float> groupTop2SumTensor = top2ValueInGroupTensor;
|
|
||||||
PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1,
|
|
||||||
1); // 计算每个组内最大的两个数之和
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
LocalTensor<uint32_t> groupIndexTensor = indexTensor;
|
|
||||||
ArithProgression(groupIndexTensor.ReinterpretCast<int32_t>(), 0, 1, groupCount_); // 生成组索引
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
// 用最小值补到32个数
|
|
||||||
int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX << groupCount_;
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
// 排序,将kgroup选出来
|
|
||||||
LocalTensor<float> sortedGroupTensor = sortedGroupQueue_.AllocTensor<float>();
|
|
||||||
Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1);
|
|
||||||
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
LocalTensor<int32_t> sortedGroupIndexTensor = indexTensor.ReinterpretCast<int32_t>();
|
|
||||||
// 提取组序号
|
|
||||||
uint8_t src1Pattern = 2; // 内置固定模式
|
|
||||||
GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
|
|
||||||
static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
|
|
||||||
|
|
||||||
// 需要将组排序(这里是降序,所以下mrgsor的时候反着取,3、2、1、0)
|
|
||||||
Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX << kGroup_;
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast<uint32_t>(), 1);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
src1Pattern = 1;
|
|
||||||
GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast<uint32_t>(0), {1, 1, 0, 0},
|
|
||||||
rsvdCnt);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_);
|
|
||||||
|
|
||||||
sortedGroupQueue_.FreeTensor(sortedGroupTensor);
|
|
||||||
sortedInGroupQueue_.EnQue<float>(sortedInGroupTensor);
|
|
||||||
sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertIdx()
|
|
||||||
{
|
|
||||||
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.AllocTensor<int32_t>();
|
|
||||||
LocalTensor<int32_t> topKGroupIndexTensor = calcTmpBuffer_.Get<int32_t>();
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupQueue_.DeQue<float>();
|
|
||||||
LocalTensor<float> sortedExpertTensor = xInQueue_.AllocTensor<float>();
|
|
||||||
AscendC::MrgSort4Info params;
|
|
||||||
params.elementLengths[0] = k_;
|
|
||||||
params.elementLengths[1] = k_;
|
|
||||||
params.elementLengths[2] = k_;
|
|
||||||
params.elementLengths[3] = k_;
|
|
||||||
params.ifExhaustedSuspension = true;
|
|
||||||
params.validBit = 0b1111;
|
|
||||||
params.repeatTimes = 1;
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2;
|
|
||||||
int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2;
|
|
||||||
int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2;
|
|
||||||
int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2;
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
AscendC::MrgSortSrcList<float> srcList;
|
|
||||||
srcList.src1 = sortedInGroupTensor[listOffset1];
|
|
||||||
srcList.src2 = sortedInGroupTensor[listOffset2];
|
|
||||||
srcList.src3 = sortedInGroupTensor[listOffset3];
|
|
||||||
srcList.src4 = sortedInGroupTensor[listOffset4];
|
|
||||||
MrgSort<float>(sortedExpertTensor, srcList, params);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
uint8_t src1Pattern = 2; // 内置固定模式
|
|
||||||
GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
|
|
||||||
static_cast<uint32_t>(0), {1, 1, 0, 0}, rsvdCnt);
|
|
||||||
xInQueue_.FreeTensor(sortedExpertTensor);
|
|
||||||
expertIdxOutQueue_.EnQue(expertIdxTensor);
|
|
||||||
sortedInGroupQueue_.FreeTensor(sortedInGroupTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::SelectTopKExpertScore()
|
|
||||||
{
|
|
||||||
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
|
|
||||||
LocalTensor<int32_t> expertByteIdxTensor = calcTmpBuffer_.Get<int32_t>();
|
|
||||||
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
|
|
||||||
LocalTensor<T> yTensor = yOutQueue_.AllocTensor<T>();
|
|
||||||
LocalTensor<float> yOutTensor;
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
yOutTensor = yTensor.template ReinterpretCast<float>()[kAlign_];
|
|
||||||
} else {
|
|
||||||
yOutTensor = yTensor;
|
|
||||||
}
|
|
||||||
Muls(expertByteIdxTensor, expertIdxTensor, static_cast<int32_t>(sizeof(float)), k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast<uint32_t>(),
|
|
||||||
static_cast<uint32_t>(0), k_);
|
|
||||||
|
|
||||||
LocalTensor<float> calTensor = calcTmpBuffer_.Get<float>();
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_);
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float sumValue = calTensor.GetValue(0) + eps_;
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Duplicate(calTensor, sumValue, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Div(yOutTensor, yOutTensor, calTensor, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_);
|
|
||||||
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_);
|
|
||||||
}
|
|
||||||
|
|
||||||
xSigmoidQueue_.EnQue<float>(xSigmoidTensor);
|
|
||||||
expertIdxOutQueue_.EnQue<int32_t>(expertIdxTensor);
|
|
||||||
yOutQueue_.EnQue(yTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::CopyOut(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
|
|
||||||
LocalTensor<int32_t> expertIdxTensor = expertIdxOutQueue_.DeQue<int32_t>();
|
|
||||||
LocalTensor<float> xSigmoidTensor = xSigmoidQueue_.DeQue<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
|
|
||||||
dataCopyParams.blockLen = k_ * sizeof(int32_t);
|
|
||||||
DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams);
|
|
||||||
xSigmoidQueue_.FreeTensor(xSigmoidTensor);
|
|
||||||
expertIdxOutQueue_.FreeTensor(expertIdxTensor);
|
|
||||||
yOutQueue_.FreeTensor(yOutTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
|
|
||||||
GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
|
|
||||||
{
|
|
||||||
tilingData_ = tilingData;
|
|
||||||
pipe_ = tPipe;
|
|
||||||
blockIdx_ = GetBlockIdx();
|
|
||||||
perCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
if (blockIdx_ == GetBlockNum() - 1) {
|
|
||||||
curCoreRowCount_ = tilingData_->lastCoreRowCount;
|
|
||||||
} else {
|
|
||||||
curCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
}
|
|
||||||
expertCount_ = tilingData_->expertCount;
|
|
||||||
addBias_ = tilingData_->addBias == 1;
|
|
||||||
k_ = tilingData_->k;
|
|
||||||
kGroup_ = tilingData_->kGroup;
|
|
||||||
groupCount_ = tilingData_->groupCount;
|
|
||||||
perGroupExpertCount_ = tilingData_->perGroupExpertCount;
|
|
||||||
routedScalingFactor_ = tilingData_->routedScalingFactor;
|
|
||||||
eps_ = tilingData_->eps;
|
|
||||||
|
|
||||||
expertCountAlign_ = Align(expertCount_, sizeof(float));
|
|
||||||
kAlign_ = Align(expertCount_, sizeof(float));
|
|
||||||
|
|
||||||
// init input gm buf
|
|
||||||
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
|
|
||||||
|
|
||||||
// init output gm buf
|
|
||||||
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
|
|
||||||
// init que
|
|
||||||
pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float)));
|
|
||||||
pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t)));
|
|
||||||
pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float)));
|
|
||||||
pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2);
|
|
||||||
pipe_->InitBuffer(sortedGroupQueue_, 2,
|
|
||||||
(groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM *
|
|
||||||
sizeof(float) * 2);
|
|
||||||
|
|
||||||
pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKEKFullload<T>::Process()
|
|
||||||
{
|
|
||||||
CopyInBias();
|
|
||||||
for (int64_t row = 0; row < curCoreRowCount_; row++) {
|
|
||||||
CopyInX(row);
|
|
||||||
ComputeX();
|
|
||||||
SortInGroup();
|
|
||||||
SelectTopKGroupIndex();
|
|
||||||
SelectTopKExpertIdx();
|
|
||||||
SelectTopKExpertScore();
|
|
||||||
CopyOut(row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace MoeGatingTopK
|
|
||||||
#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H
|
|
||||||
@@ -1,669 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_generalized.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
|
||||||
#define MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
|
||||||
#include "kernel_operator.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "kernel_utils.h"
|
|
||||||
namespace MoeGatingTopK {
|
|
||||||
using namespace AscendC;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class MoeGatingTopKGenerlized {
|
|
||||||
public:
|
|
||||||
__aicore__ inline MoeGatingTopKGenerlized(){};
|
|
||||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
|
|
||||||
__aicore__ inline void Process();
|
|
||||||
|
|
||||||
private:
|
|
||||||
__aicore__ inline void CopyInBiasAndInitExpertId();
|
|
||||||
__aicore__ inline void CopyInX(int64_t progress);
|
|
||||||
__aicore__ inline void ComputeX();
|
|
||||||
__aicore__ inline void CopuOutXNorm(int64_t row);
|
|
||||||
__aicore__ inline void SortInGroup();
|
|
||||||
__aicore__ inline void SelectTopKGroupIndex();
|
|
||||||
__aicore__ inline void SelectTopKExpertIdx();
|
|
||||||
__aicore__ inline void SelectTopKExpertScore();
|
|
||||||
__aicore__ inline void CumputeActualTopKExpertId();
|
|
||||||
__aicore__ inline void CopyOut(int64_t row);
|
|
||||||
|
|
||||||
private:
|
|
||||||
TPipe *pipe_;
|
|
||||||
TQue<QuePosition::VECIN, 1> xInQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> yOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> outOutQueue_;
|
|
||||||
|
|
||||||
TBuf<TPosition::VECCALC> biasBuf_; // 存放输入bias
|
|
||||||
TBuf<TPosition::VECCALC> expertIdBuf_; // 专家编号
|
|
||||||
TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // 存放加了bias之后的值
|
|
||||||
TBuf<TPosition::VECCALC> xNormBuf_; // 存放计算sigmoid或softmax的值
|
|
||||||
TBuf<TPosition::VECCALC> sortedInGroupBuf_; // 存放组内排序后的结果
|
|
||||||
TBuf<TPosition::VECCALC> topKExpertIdBuf_;
|
|
||||||
TBuf<TPosition::VECCALC> sortedGroupIndexBuf_;
|
|
||||||
TBuf<TPosition::VECCALC> calcTmpBuf_;
|
|
||||||
|
|
||||||
GlobalTensor<T> xGm_;
|
|
||||||
GlobalTensor<T> biasGm_;
|
|
||||||
GlobalTensor<T> yGm_;
|
|
||||||
GlobalTensor<int32_t> expertIdxGm_;
|
|
||||||
GlobalTensor<float> outGm_;
|
|
||||||
|
|
||||||
int64_t blockIdx_ = 0;
|
|
||||||
int64_t perCoreRowCount_ = 0;
|
|
||||||
int64_t curCoreRowCount_ = 0;
|
|
||||||
int64_t expertCount_ = 0;
|
|
||||||
bool addBias_ = false;
|
|
||||||
int64_t k_ = 0;
|
|
||||||
int64_t kGroup_ = 0;
|
|
||||||
int64_t groupCount_ = 0;
|
|
||||||
int64_t groupCountAlign_ = 0;
|
|
||||||
int64_t perGroupExpertCount_ = 0;
|
|
||||||
int64_t perGroupExpertCountAlign_ = 0;
|
|
||||||
int64_t groupSelectMode_ = 0;
|
|
||||||
int64_t renorm_ = 0;
|
|
||||||
int64_t normType_ = 0;
|
|
||||||
int64_t outFlag_ = 0;
|
|
||||||
|
|
||||||
int64_t expertCountAlign_ = 0;
|
|
||||||
int64_t kAlign_ = 0;
|
|
||||||
bool isAlign_ = false;
|
|
||||||
|
|
||||||
const MoeGatingTopKTilingData *tilingData_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInBiasAndInitExpertId()
|
|
||||||
{
|
|
||||||
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
|
|
||||||
LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
|
|
||||||
DataCopyExtParams dataCopyParams;
|
|
||||||
dataCopyParams.blockCount = groupCount_;
|
|
||||||
dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
|
|
||||||
dataCopyParams.srcStride = 0;
|
|
||||||
dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
|
|
||||||
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if (addBias_) {
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
|
|
||||||
expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isAlign_) {
|
|
||||||
int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
|
|
||||||
int duplicateIndex = perGroupExpertCount_ - duplicateNum;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(biasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
|
|
||||||
perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCountAlign_);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyInX(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
|
|
||||||
DataCopyExtParams dataCopyParams;
|
|
||||||
dataCopyParams.blockCount = groupCount_;
|
|
||||||
dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T);
|
|
||||||
dataCopyParams.srcStride = 0;
|
|
||||||
dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES;
|
|
||||||
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
|
|
||||||
dataCopyPadParams);
|
|
||||||
}
|
|
||||||
xInQueue_.EnQue(xInLocalTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::ComputeX()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
|
|
||||||
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
|
|
||||||
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
|
|
||||||
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
|
|
||||||
expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM;
|
|
||||||
int duplicateIndex = perGroupExpertCount_ - duplicateNum;
|
|
||||||
if (!isAlign_ && duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(xInLocalTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1,
|
|
||||||
(perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
if (normType_ == 1) { // sigmoid
|
|
||||||
LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
|
|
||||||
Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
else if (normType_ == 0) { // softmax
|
|
||||||
LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
|
|
||||||
LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
|
|
||||||
ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_);
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float maxValue = reduceValueTensor.GetValue(0);
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Exp(xNormTensor, xNormTensor, expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_);
|
|
||||||
eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float sumValue = reduceValueTensor.GetValue(0);
|
|
||||||
eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
if (addBias_) {
|
|
||||||
Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_);
|
|
||||||
} else {
|
|
||||||
DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isAlign_ && duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex],
|
|
||||||
FLOAT32_NEG_INF, // MIN_FP32,
|
|
||||||
mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES);
|
|
||||||
}
|
|
||||||
xInQueue_.FreeTensor(xInLocalTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopuOutXNorm(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
|
|
||||||
outOutQueue_.EnQue<float>(outOutTensor);
|
|
||||||
outOutTensor = outOutQueue_.DeQue<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{
|
|
||||||
static_cast<uint16_t>(groupCount_), static_cast<uint32_t>(perGroupExpertCount_ * sizeof(float)),
|
|
||||||
static_cast<uint32_t>((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0};
|
|
||||||
DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
|
|
||||||
outOutQueue_.FreeTensor(outOutTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::SortInGroup()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
|
|
||||||
LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
|
|
||||||
LocalTensor<float> tmpLocal = calcTmpBuf_.Get<float>();
|
|
||||||
if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_);
|
|
||||||
} else {
|
|
||||||
for (int64_t group = 0; group < groupCount_; group++) {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Sort<float, true>(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO],
|
|
||||||
xNormWithBiasTensor[group * perGroupExpertCountAlign_],
|
|
||||||
expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal,
|
|
||||||
perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKGroupIndex()
|
|
||||||
{
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
|
|
||||||
LocalTensor<float> valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, 0);
|
|
||||||
LocalTensor<uint32_t> maskTensor =
|
|
||||||
calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float));
|
|
||||||
LocalTensor<float> topValueInGroupTensor =
|
|
||||||
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float));
|
|
||||||
LocalTensor<uint32_t> groupIndex =
|
|
||||||
calcTmpBuf_.GetWithOffset<uint32_t>(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float));
|
|
||||||
LocalTensor<float> sortedTopValue =
|
|
||||||
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float));
|
|
||||||
LocalTensor<float> sortTmp =
|
|
||||||
calcTmpBuf_.GetWithOffset<float>(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float));
|
|
||||||
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
if (groupSelectMode_ == 1) { // top2 sum
|
|
||||||
// 提取每组组前两个元素
|
|
||||||
maskTensor.SetValue(0, static_cast<uint32_t>(5)); // b0101
|
|
||||||
maskTensor.SetValue(1, static_cast<uint32_t>(0));
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
|
|
||||||
GatherMaskParams gatherMaskParams;
|
|
||||||
gatherMaskParams.repeatTimes = groupCount_;
|
|
||||||
gatherMaskParams.src0BlockStride = 1;
|
|
||||||
gatherMaskParams.src0RepeatStride =
|
|
||||||
Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES);
|
|
||||||
gatherMaskParams.src1RepeatStride = 0;
|
|
||||||
GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true,
|
|
||||||
static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
// 计算每个组前两个数的和
|
|
||||||
PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor,
|
|
||||||
Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1,
|
|
||||||
CONSTANT_EIGHT); // 计算每个组内最大的两个数之和
|
|
||||||
} else {
|
|
||||||
maskTensor.SetValue(0, static_cast<uint32_t>(1)); // b0101
|
|
||||||
maskTensor.SetValue(1, static_cast<uint32_t>(0));
|
|
||||||
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
GatherMaskParams gatherMaskParams;
|
|
||||||
gatherMaskParams.repeatTimes = groupCount_;
|
|
||||||
gatherMaskParams.src0BlockStride = 1;
|
|
||||||
gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32);
|
|
||||||
gatherMaskParams.src1RepeatStride = 0;
|
|
||||||
GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true,
|
|
||||||
static_cast<uint32_t>(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt);
|
|
||||||
}
|
|
||||||
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
// 生成组索引
|
|
||||||
ArithProgression(groupIndex.ReinterpretCast<int32_t>(), static_cast<int32_t>(0), static_cast<int32_t>(1),
|
|
||||||
groupCount_); // 生成组索引
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM;
|
|
||||||
int duplicateIndex = groupCount_ - duplicateNum;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(topValueInGroupTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1,
|
|
||||||
REPEAT_BLOCKS);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
// 排序
|
|
||||||
Sort<float, true>(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32));
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
// 提取组序号
|
|
||||||
uint8_t src1Pattern = 2; // 内置固定模式
|
|
||||||
GatherMask(groupIndex, sortedTopValue.template ReinterpretCast<uint32_t>(), src1Pattern, false,
|
|
||||||
static_cast<uint32_t>(0),
|
|
||||||
{1, static_cast<uint8_t>(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
duplicateIndex = kGroup_ - duplicateNum;
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Duplicate(groupIndex.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 将筛选出来的组序号降序排列
|
|
||||||
LocalTensor<float> sortedGroupIndex = sortedGroupIndexBuf_.Get<float>();
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Sort<float, true>(sortedGroupIndex, groupIndex.ReinterpretCast<float>(), groupIndex, sortTmp, Ceil(kGroup_, 32));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertIdx()
|
|
||||||
{
|
|
||||||
LocalTensor<float> sortedInGroupTensor = sortedInGroupBuf_.Get<float>();
|
|
||||||
LocalTensor<int32_t> sortedGroupIndex = sortedGroupIndexBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<float> mrgSort0Tensor = calcTmpBuf_.Get<float>();
|
|
||||||
|
|
||||||
uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0};
|
|
||||||
uint16_t lenArr[CONSTANT_FOUR] = {
|
|
||||||
static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_),
|
|
||||||
static_cast<uint16_t>(perGroupExpertCount_), static_cast<uint16_t>(perGroupExpertCount_)};
|
|
||||||
MrgSort4Info params{lenArr, false, 0b1111, 1};
|
|
||||||
MrgSortSrcList<float> srcList;
|
|
||||||
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
|
|
||||||
for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) {
|
|
||||||
int64_t mrgLen = Min(i + 1, CONSTANT_FOUR);
|
|
||||||
if (mrgLen > 1) {
|
|
||||||
if (mrgLen == MERGE_LIST_FOUR) {
|
|
||||||
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
} else if (mrgLen == MERGE_LIST_THREE) {
|
|
||||||
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[3] = 0;
|
|
||||||
params.elementLengths[3] = 0;
|
|
||||||
params.validBit = 0b111;
|
|
||||||
} else {
|
|
||||||
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
offset[2] = 0;
|
|
||||||
offset[3] = 0;
|
|
||||||
params.elementLengths[2] = 0;
|
|
||||||
params.elementLengths[3] = 0;
|
|
||||||
params.validBit = 0b11;
|
|
||||||
}
|
|
||||||
|
|
||||||
srcList.src1 = sortedInGroupTensor[offset[0]];
|
|
||||||
srcList.src2 = sortedInGroupTensor[offset[1]];
|
|
||||||
srcList.src3 = sortedInGroupTensor[offset[2]];
|
|
||||||
srcList.src4 = sortedInGroupTensor[offset[3]];
|
|
||||||
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params);
|
|
||||||
} else {
|
|
||||||
offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2;
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]],
|
|
||||||
perGroupExpertCountAlign_ * 2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int32_t baseLoop = 4;
|
|
||||||
LocalTensor<float> srcTensor = mrgSort0Tensor;
|
|
||||||
LocalTensor<float> dstTensor = mrgSort0Tensor;
|
|
||||||
for (int i = 0; i < tilingData_->vmsCount; i++) {
|
|
||||||
if (i % 2 == 0) {
|
|
||||||
srcTensor = mrgSort0Tensor;
|
|
||||||
dstTensor = sortedInGroupTensor;
|
|
||||||
} else {
|
|
||||||
srcTensor = sortedInGroupTensor;
|
|
||||||
dstTensor = mrgSort0Tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR;
|
|
||||||
int32_t quotient = kGroup_ / nextBaseRow;
|
|
||||||
int32_t remainder = kGroup_ - quotient * nextBaseRow;
|
|
||||||
if (quotient > 0) {
|
|
||||||
MrgSort4Info params;
|
|
||||||
MrgSortSrcList<float> srcList;
|
|
||||||
params.ifExhaustedSuspension = false;
|
|
||||||
params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.validBit = 0b1111;
|
|
||||||
params.repeatTimes = 1;
|
|
||||||
for (int j = 0; j < quotient; j++) {
|
|
||||||
srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j];
|
|
||||||
srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)];
|
|
||||||
srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)];
|
|
||||||
srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)];
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remainder > 0) {
|
|
||||||
int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2;
|
|
||||||
int32_t mrgLen = CeilDiv(remainder, baseLoop);
|
|
||||||
int32_t tailRow = remainder - (mrgLen - 1) * baseLoop;
|
|
||||||
if (mrgLen > 1) {
|
|
||||||
MrgSort4Info params;
|
|
||||||
MrgSortSrcList<float> srcList;
|
|
||||||
params.repeatTimes = 1;
|
|
||||||
params.ifExhaustedSuspension = false;
|
|
||||||
params.elementLengths[0] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[1] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[2] = perGroupExpertCount_ * baseLoop;
|
|
||||||
params.elementLengths[3] = perGroupExpertCount_ * baseLoop;
|
|
||||||
srcList.src1 = srcTensor[baseOffset];
|
|
||||||
srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2];
|
|
||||||
if (mrgLen == MERGE_LIST_FOUR) {
|
|
||||||
srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
|
|
||||||
srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3];
|
|
||||||
params.elementLengths[3] = perGroupExpertCount_ * tailRow;
|
|
||||||
params.validBit = 0b1111;
|
|
||||||
} else if (mrgLen == MERGE_LIST_THREE) {
|
|
||||||
srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2];
|
|
||||||
params.elementLengths[2] = perGroupExpertCount_ * tailRow;
|
|
||||||
params.elementLengths[3] = 0;
|
|
||||||
params.validBit = 0b111;
|
|
||||||
} else {
|
|
||||||
params.elementLengths[1] = perGroupExpertCount_ * tailRow;
|
|
||||||
params.elementLengths[2] = 0;
|
|
||||||
params.elementLengths[3] = 0;
|
|
||||||
params.validBit = 0b11;
|
|
||||||
}
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
MrgSort(dstTensor[baseOffset], srcList, params);
|
|
||||||
} else {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
baseLoop = nextBaseRow;
|
|
||||||
}
|
|
||||||
|
|
||||||
GatherMaskParams gatherMaskParams;
|
|
||||||
gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES);
|
|
||||||
gatherMaskParams.src0BlockStride = 1;
|
|
||||||
gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
|
|
||||||
gatherMaskParams.src1RepeatStride = 0;
|
|
||||||
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
uint8_t src1Pattern = 2; // 内置固定模式
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
GatherMask(topKExpertId, dstTensor.template ReinterpretCast<int32_t>(), src1Pattern, false,
|
|
||||||
static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::SelectTopKExpertScore()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
|
|
||||||
k_);
|
|
||||||
bool needRenorm = (normType_ == 1 ) || // 情况1:sigmoid + renorm
|
|
||||||
(normType_ == 0 && renorm_ == 1); // 情况3:softmax + renorm
|
|
||||||
if (needRenorm) {
|
|
||||||
LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
|
|
||||||
LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[32];
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Duplicate(tmpTensor, sumValue, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Div(yOutTensor, yOutTensor, tmpTensor, k_);
|
|
||||||
}
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
|
|
||||||
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
|
|
||||||
}
|
|
||||||
|
|
||||||
yOutQueue_.EnQue<float>(yOutTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::CumputeActualTopKExpertId()
|
|
||||||
{
|
|
||||||
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
|
|
||||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<float> topKExpertIdFp32 = calcTmpBuf_.Get<float>();
|
|
||||||
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(expertIdxOut, topKExpertIdFp32, RoundMode::CAST_TRUNC, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(expertIdxOut, expertIdxOut, static_cast<int32_t>(perGroupExpertCountAlign_ - perGroupExpertCount_), k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Sub(expertIdxOut, topKExpertId, expertIdxOut, k_);
|
|
||||||
expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::CopyOut(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
|
|
||||||
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
|
|
||||||
dataCopyParams.blockLen = k_ * sizeof(int32_t);
|
|
||||||
DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
|
|
||||||
yOutQueue_.FreeTensor(yOutTensor);
|
|
||||||
expertIdxOutQueue_.FreeTensor(expertIdxOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
|
|
||||||
GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
|
|
||||||
{
|
|
||||||
tilingData_ = tilingData;
|
|
||||||
pipe_ = tPipe;
|
|
||||||
blockIdx_ = GetBlockIdx();
|
|
||||||
perCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
if (blockIdx_ == GetBlockNum() - 1) {
|
|
||||||
curCoreRowCount_ = tilingData_->lastCoreRowCount;
|
|
||||||
} else {
|
|
||||||
curCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
}
|
|
||||||
expertCount_ = tilingData_->expertCount;
|
|
||||||
addBias_ = tilingData_->addBias == 1;
|
|
||||||
k_ = tilingData_->k;
|
|
||||||
kGroup_ = tilingData_->kGroup;
|
|
||||||
groupCount_ = tilingData_->groupCount;
|
|
||||||
groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
|
|
||||||
perGroupExpertCount_ = tilingData_->perGroupExpertCount;
|
|
||||||
perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign;
|
|
||||||
renorm_ = tilingData_->renorm;
|
|
||||||
normType_ = tilingData_->normType;
|
|
||||||
groupSelectMode_ = tilingData_->groupSelectMode;
|
|
||||||
|
|
||||||
expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float));
|
|
||||||
kAlign_ = Align(k_, sizeof(float));
|
|
||||||
|
|
||||||
isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_;
|
|
||||||
|
|
||||||
// init input gm buf
|
|
||||||
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
|
|
||||||
|
|
||||||
// init output gm buf
|
|
||||||
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
|
|
||||||
// init que
|
|
||||||
pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float));
|
|
||||||
pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t));
|
|
||||||
pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
|
|
||||||
pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t)));
|
|
||||||
|
|
||||||
pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO);
|
|
||||||
pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t));
|
|
||||||
pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKGenerlized<T>::Process()
|
|
||||||
{
|
|
||||||
CopyInBiasAndInitExpertId();
|
|
||||||
for (int64_t row = 0; row < curCoreRowCount_; row++) {
|
|
||||||
CopyInX(row);
|
|
||||||
ComputeX();
|
|
||||||
if (tilingData_->outFlag) {
|
|
||||||
CopuOutXNorm(row);
|
|
||||||
}
|
|
||||||
SortInGroup();
|
|
||||||
SelectTopKGroupIndex();
|
|
||||||
SelectTopKExpertIdx();
|
|
||||||
SelectTopKExpertScore();
|
|
||||||
CumputeActualTopKExpertId();
|
|
||||||
CopyOut(row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace MoeGatingTopK
|
|
||||||
#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H
|
|
||||||
@@ -1,338 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file moe_gating_top_k_without_group.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
|
||||||
#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
|
||||||
#include "kernel_operator.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "kernel_utils.h"
|
|
||||||
namespace MoeGatingTopK {
|
|
||||||
using namespace AscendC;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class MoeGatingTopKWithoutGroup {
|
|
||||||
public:
|
|
||||||
__aicore__ inline MoeGatingTopKWithoutGroup(){};
|
|
||||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe);
|
|
||||||
__aicore__ inline void Process();
|
|
||||||
|
|
||||||
private:
|
|
||||||
__aicore__ inline void CopyInBiasAndInitExpertId();
|
|
||||||
__aicore__ inline void CopyInX(int64_t progress);
|
|
||||||
__aicore__ inline void ComputeX();
|
|
||||||
__aicore__ inline void CopuOutXNorm(int64_t row);
|
|
||||||
__aicore__ inline void SelectTopKExpertIdx();
|
|
||||||
__aicore__ inline void SelectTopKExpertScore();
|
|
||||||
__aicore__ inline void CopyOut(int64_t row);
|
|
||||||
|
|
||||||
private:
|
|
||||||
TPipe *pipe_;
|
|
||||||
TQue<QuePosition::VECIN, 1> xInQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> yOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> expertIdxOutQueue_;
|
|
||||||
TQue<QuePosition::VECOUT, 1> outOutQueue_;
|
|
||||||
|
|
||||||
TBuf<TPosition::VECCALC> biasBuf_; // 存放输入bias
|
|
||||||
TBuf<TPosition::VECCALC> expertIdBuf_; // 专家编号
|
|
||||||
TBuf<TPosition::VECCALC> xNormWithBiasBuf_; // 存放加了bias之后的值
|
|
||||||
TBuf<TPosition::VECCALC> xNormBuf_; // 存放计算sigmoid或softmax的值
|
|
||||||
TBuf<TPosition::VECCALC> topKExpertIdBuf_;
|
|
||||||
TBuf<TPosition::VECCALC> calcTmpBuf_;
|
|
||||||
|
|
||||||
GlobalTensor<T> xGm_;
|
|
||||||
GlobalTensor<T> biasGm_;
|
|
||||||
GlobalTensor<T> yGm_;
|
|
||||||
GlobalTensor<int32_t> expertIdxGm_;
|
|
||||||
GlobalTensor<float> outGm_;
|
|
||||||
|
|
||||||
int64_t blockIdx_ = 0;
|
|
||||||
int64_t perCoreRowCount_ = 0;
|
|
||||||
int64_t curCoreRowCount_ = 0;
|
|
||||||
int64_t expertCount_ = 0;
|
|
||||||
bool addBias_ = false;
|
|
||||||
bool outFlag_ = false;
|
|
||||||
int64_t k_ = 0;
|
|
||||||
int64_t renorm_ = 0;
|
|
||||||
int64_t normType_ = 0;
|
|
||||||
int64_t expertCountAlign_ = 0;
|
|
||||||
const MoeGatingTopKTilingData *tilingData_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInBiasAndInitExpertId()
|
|
||||||
{
|
|
||||||
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
|
|
||||||
LocalTensor<int32_t> expertIdTensor = expertIdBuf_.Get<int32_t>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if (addBias_) {
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast<T>(), biasGm_, dataCopyParams, dataCopyPadParams);
|
|
||||||
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
|
||||||
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
|
|
||||||
Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
|
|
||||||
expertCountAlign_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ArithProgression(expertIdTensor, static_cast<int32_t>(0), static_cast<int32_t>(1), expertCount_);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyInX(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.AllocTensor<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast<T>(0)};
|
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
|
||||||
DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams);
|
|
||||||
} else {
|
|
||||||
DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), xGm_[row * expertCount_], dataCopyParams,
|
|
||||||
dataCopyPadParams);
|
|
||||||
}
|
|
||||||
xInQueue_.EnQue(xInLocalTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::ComputeX()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
LocalTensor<float> xInLocalTensor = xInQueue_.DeQue<float>();
|
|
||||||
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
|
|
||||||
LocalTensor<float> biasTensor = biasBuf_.Get<float>();
|
|
||||||
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast<T>(), RoundMode::CAST_NONE,
|
|
||||||
expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (normType_ == 1) { // sigmoid
|
|
||||||
LocalTensor<uint8_t> calcNormTmpTensor = calcTmpBuf_.Get<uint8_t>();
|
|
||||||
Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
} else if (normType_ == 0) { // sigmoid
|
|
||||||
LocalTensor<float> reduceValueTensor = calcTmpBuf_.Get<float>();
|
|
||||||
LocalTensor<float> calcTmp = calcTmpBuf_.Get<float>()[8];
|
|
||||||
ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_);
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float maxValue = reduceValueTensor.GetValue(0);
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Exp(xNormTensor, xNormTensor, expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_);
|
|
||||||
eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float sumValue = reduceValueTensor.GetValue(0);
|
|
||||||
eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
if (addBias_) {
|
|
||||||
Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_);
|
|
||||||
} else {
|
|
||||||
DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM;
|
|
||||||
int duplicateIndex = expertCount_ - duplicateNum;
|
|
||||||
if (duplicateNum > 0) {
|
|
||||||
uint64_t mask0 = UINT64_MAX;
|
|
||||||
mask0 = mask0 << duplicateNum;
|
|
||||||
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
|
|
||||||
uint64_t mask[2] = {mask0, 0};
|
|
||||||
Duplicate(xNormWithBiasTensor.ReinterpretCast<int32_t>()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
}
|
|
||||||
xInQueue_.FreeTensor(xInLocalTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopuOutXNorm(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<float> outOutTensor = outOutQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
DataCopy(outOutTensor, xNormTensor, expertCountAlign_);
|
|
||||||
outOutQueue_.EnQue<float>(outOutTensor);
|
|
||||||
outOutTensor = outOutQueue_.DeQue<float>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(expertCount_ * sizeof(float)), 0, 0, 0};
|
|
||||||
DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams);
|
|
||||||
outOutQueue_.FreeTensor(outOutTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertIdx()
|
|
||||||
{
|
|
||||||
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.AllocTensor<int32_t>();
|
|
||||||
LocalTensor<float> xNormWithBiasTensor = xNormWithBiasBuf_.Get<float>();
|
|
||||||
LocalTensor<uint32_t> expertIdTensor = expertIdBuf_.Get<uint32_t>();
|
|
||||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<float> sortedScore = calcTmpBuf_.Get<float>();
|
|
||||||
LocalTensor<float> sortTmp = calcTmpBuf_.Get<float>()[expertCountAlign_ * CONSTANT_TWO];
|
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
Sort<float, true>(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp,
|
|
||||||
expertCountAlign_ / ONE_REPEAT_SORT_NUM);
|
|
||||||
|
|
||||||
GatherMaskParams gatherMaskParams;
|
|
||||||
gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES);
|
|
||||||
gatherMaskParams.src0BlockStride = 1;
|
|
||||||
gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS;
|
|
||||||
gatherMaskParams.src1RepeatStride = 0;
|
|
||||||
|
|
||||||
uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数
|
|
||||||
uint8_t src1Pattern = 2; // 内置固定模式
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
GatherMask(topKExpertId, sortedScore.template ReinterpretCast<int32_t>(), src1Pattern, false,
|
|
||||||
static_cast<uint32_t>(0), gatherMaskParams, rsvdCnt);
|
|
||||||
|
|
||||||
DataCopy(expertIdxOut, topKExpertId, expertCountAlign_);
|
|
||||||
expertIdxOutQueue_.EnQue<int32_t>(expertIdxOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::SelectTopKExpertScore()
|
|
||||||
{
|
|
||||||
LocalTensor<float> xNormTensor = xNormBuf_.Get<float>();
|
|
||||||
LocalTensor<float> yOutTensor = yOutQueue_.AllocTensor<float>();
|
|
||||||
LocalTensor<int32_t> topKExpertId = topKExpertIdBuf_.Get<int32_t>();
|
|
||||||
LocalTensor<int32_t> topKExpertIdWithByte = calcTmpBuf_.Get<int32_t>();
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(topKExpertIdWithByte, topKExpertId, static_cast<int32_t>(sizeof(float)), k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast<uint32_t>(), static_cast<uint32_t>(0),
|
|
||||||
k_);
|
|
||||||
|
|
||||||
bool needRenorm = (normType_ == 1 ) || // 情况1:sigmoid + renorm
|
|
||||||
(normType_ == 0 && renorm_ == 1); // 情况3:softmax + renorm
|
|
||||||
if (needRenorm == 1) {
|
|
||||||
LocalTensor<float> maxValueTensor = calcTmpBuf_.Get<float>();
|
|
||||||
LocalTensor<float> tmpTensor = calcTmpBuf_.Get<float>()[BLOCK_BYTES];
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_);
|
|
||||||
event_t eventIdVToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
|
||||||
SetFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
WaitFlag<HardEvent::V_S>(eventIdVToS);
|
|
||||||
float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps;
|
|
||||||
event_t eventIdSToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
|
||||||
SetFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
WaitFlag<HardEvent::S_V>(eventIdSToV);
|
|
||||||
Duplicate(tmpTensor, sumValue, k_);
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Div(yOutTensor, yOutTensor, tmpTensor, k_);
|
|
||||||
}
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_);
|
|
||||||
|
|
||||||
if constexpr (!IsSameType<T, float>::value) {
|
|
||||||
PipeBarrier<PIPE_V>();
|
|
||||||
Cast(yOutTensor.ReinterpretCast<T>(), yOutTensor, RoundMode::CAST_RINT, k_);
|
|
||||||
}
|
|
||||||
|
|
||||||
yOutQueue_.EnQue<float>(yOutTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::CopyOut(int64_t row)
|
|
||||||
{
|
|
||||||
LocalTensor<T> yOutTensor = yOutQueue_.DeQue<T>();
|
|
||||||
LocalTensor<int32_t> expertIdxOut = expertIdxOutQueue_.DeQue<int32_t>();
|
|
||||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(k_ * sizeof(T)), 0, 0, 0};
|
|
||||||
DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams);
|
|
||||||
dataCopyParams.blockLen = k_ * sizeof(int32_t);
|
|
||||||
DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams);
|
|
||||||
yOutQueue_.FreeTensor(yOutTensor);
|
|
||||||
expertIdxOutQueue_.FreeTensor(expertIdxOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx,
|
|
||||||
GM_ADDR out, GM_ADDR workspace,
|
|
||||||
const MoeGatingTopKTilingData *tilingData, TPipe *tPipe)
|
|
||||||
{
|
|
||||||
tilingData_ = tilingData;
|
|
||||||
pipe_ = tPipe;
|
|
||||||
blockIdx_ = GetBlockIdx();
|
|
||||||
perCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
if (blockIdx_ == GetBlockNum() - 1) {
|
|
||||||
curCoreRowCount_ = tilingData_->lastCoreRowCount;
|
|
||||||
} else {
|
|
||||||
curCoreRowCount_ = tilingData_->perCoreRowCount;
|
|
||||||
}
|
|
||||||
expertCount_ = tilingData_->expertCount;
|
|
||||||
addBias_ = tilingData_->addBias == 1;
|
|
||||||
outFlag_ = tilingData_->outFlag == 1;
|
|
||||||
k_ = tilingData_->k;
|
|
||||||
renorm_ = tilingData_->renorm;
|
|
||||||
normType_ = tilingData_->normType;
|
|
||||||
expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
|
|
||||||
|
|
||||||
// init input gm buf
|
|
||||||
xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_);
|
|
||||||
|
|
||||||
// init output gm buf
|
|
||||||
yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_);
|
|
||||||
outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_);
|
|
||||||
|
|
||||||
// init que
|
|
||||||
pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float));
|
|
||||||
pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t));
|
|
||||||
pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float));
|
|
||||||
|
|
||||||
// init calc buf
|
|
||||||
pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T)));
|
|
||||||
pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t));
|
|
||||||
pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float));
|
|
||||||
pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float));
|
|
||||||
pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t));
|
|
||||||
|
|
||||||
// init tmp buf
|
|
||||||
pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__aicore__ inline void MoeGatingTopKWithoutGroup<T>::Process()
|
|
||||||
{
|
|
||||||
CopyInBiasAndInitExpertId();
|
|
||||||
for (int64_t row = 0; row < curCoreRowCount_; row++) {
|
|
||||||
CopyInX(row);
|
|
||||||
ComputeX();
|
|
||||||
if (outFlag_) {
|
|
||||||
CopuOutXNorm(row);
|
|
||||||
}
|
|
||||||
SelectTopKExpertIdx();
|
|
||||||
SelectTopKExpertScore();
|
|
||||||
CopyOut(row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace MoeGatingTopK
|
|
||||||
#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file data_copy_transpose_tiling.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <graph/tensor.h>
|
|
||||||
#include "data_copy_transpose_tiling_def.h"
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize,
|
|
||||||
optiling::CopyTransposeTiling &tiling)
|
|
||||||
{
|
|
||||||
constexpr int64_t B_INDEX = 0;
|
|
||||||
constexpr int64_t N_INDEX = 1;
|
|
||||||
constexpr int64_t S_INDEX = 2;
|
|
||||||
constexpr int64_t H_INDEX = 3;
|
|
||||||
std::vector<int64_t> dstShapeInfo = dstShape.GetDims();
|
|
||||||
std::vector<int64_t> srcShapeInfo = srcShape.GetDims();
|
|
||||||
|
|
||||||
tiling.set_dstShapeB(dstShapeInfo[B_INDEX]);
|
|
||||||
tiling.set_dstShapeN(dstShapeInfo[N_INDEX]);
|
|
||||||
tiling.set_dstShapeS(dstShapeInfo[S_INDEX]);
|
|
||||||
tiling.set_dstShapeH(dstShapeInfo[H_INDEX]);
|
|
||||||
tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN());
|
|
||||||
|
|
||||||
tiling.set_srcShapeB(srcShapeInfo[B_INDEX]);
|
|
||||||
tiling.set_srcShapeN(srcShapeInfo[N_INDEX]);
|
|
||||||
tiling.set_srcShapeS(srcShapeInfo[S_INDEX]);
|
|
||||||
tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]);
|
|
||||||
tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize);
|
|
||||||
tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH());
|
|
||||||
tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS());
|
|
||||||
tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN());
|
|
||||||
tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file data_copy_transpose_tiling_def.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <register/tilingdata_base.h>
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
BEGIN_TILING_DATA_DEF(CopyTransposeTiling)
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeB);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeN);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeS);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, dstShapeH);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeB);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeN);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeS);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue);
|
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, paramsAlign);
|
|
||||||
END_TILING_DATA_DEF;
|
|
||||||
REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling)
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include "toolchain/slog.h"
|
|
||||||
|
|
||||||
#define OP_LOGI(opname, ...)
|
|
||||||
#define OP_LOGW(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGE(opname, ...) \
|
|
||||||
do { \
|
|
||||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
|
||||||
printf("\n"); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define OP_LOGD(opname, ...)
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
|
||||||
do { \
|
|
||||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
// 修改 OP_TILING_CHECK 宏,确保正确处理表达式
|
|
||||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
|
||||||
do { \
|
|
||||||
if (cond) { \
|
|
||||||
log_func; \
|
|
||||||
expr; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
|
||||||
do { \
|
|
||||||
if ((ptr) == nullptr) { \
|
|
||||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
|
||||||
return ge::GRAPH_FAILED; \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
|
|
||||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
|
||||||
@@ -1,256 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file tiling_base.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
#include <exe_graph/runtime/tiling_context.h>
|
|
||||||
#include <graph/utils/type_utils.h>
|
|
||||||
#include "tiling/platform/platform_ascendc.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
|
|
||||||
#ifdef ASCENDC_OP_TEST
|
|
||||||
#define ASCENDC_EXTERN_C extern "C"
|
|
||||||
#else
|
|
||||||
#define ASCENDC_EXTERN_C
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace Ops {
|
|
||||||
namespace Transformer {
|
|
||||||
namespace OpTiling {
|
|
||||||
|
|
||||||
struct AiCoreParams {
|
|
||||||
uint64_t ubSize = 0;
|
|
||||||
uint64_t blockDim = 0;
|
|
||||||
uint64_t aicNum = 0;
|
|
||||||
uint64_t l1Size = 0;
|
|
||||||
uint64_t l0aSize = 0;
|
|
||||||
uint64_t l0bSize = 0;
|
|
||||||
uint64_t l0cSize = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CompileInfoCommon {
|
|
||||||
uint32_t aivNum;
|
|
||||||
uint32_t aicNum;
|
|
||||||
uint64_t ubSize;
|
|
||||||
uint64_t l1Size;
|
|
||||||
uint64_t l0aSize;
|
|
||||||
uint64_t l0bSize;
|
|
||||||
uint64_t l0cSize;
|
|
||||||
uint64_t l2CacheSize;
|
|
||||||
int64_t coreNum;
|
|
||||||
int32_t socVersion;
|
|
||||||
uint32_t rsvd;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct FlashAttentionScoreGradCompileInfo {
|
|
||||||
uint32_t aivNum;
|
|
||||||
uint32_t aicNum;
|
|
||||||
uint64_t ubSize;
|
|
||||||
uint64_t l1Size;
|
|
||||||
uint64_t l0aSize;
|
|
||||||
uint64_t l0bSize;
|
|
||||||
uint64_t l0cSize;
|
|
||||||
uint64_t l2CacheSize;
|
|
||||||
int64_t coreNum;
|
|
||||||
platform_ascendc::SocVersion socVersion;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct FACompileInfoCommon {
|
|
||||||
uint32_t aivNum;
|
|
||||||
uint32_t aicNum;
|
|
||||||
uint64_t ubSize;
|
|
||||||
uint64_t l1Size;
|
|
||||||
uint64_t l0aSize;
|
|
||||||
uint64_t l0bSize;
|
|
||||||
uint64_t l0cSize;
|
|
||||||
uint64_t l2CacheSize;
|
|
||||||
int64_t coreNum;
|
|
||||||
int32_t socVersion;
|
|
||||||
uint32_t rsvd;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TilingBaseClass {
|
|
||||||
public:
|
|
||||||
explicit TilingBaseClass(gert::TilingContext* context) : context_(context)
|
|
||||||
{}
|
|
||||||
|
|
||||||
virtual ~TilingBaseClass() = default;
|
|
||||||
|
|
||||||
// Tiling执行框架
|
|
||||||
// 1、GRAPH_SUCCESS: 成功,并且不需要继续执行后续Tiling类的实现
|
|
||||||
// 2、GRAPH_FAILED: 失败,中止整个Tiling流程
|
|
||||||
// 3、GRAPH_PARAM_INVALID: 本类不支持,需要继续往下执行其他Tiling类的实现
|
|
||||||
ge::graphStatus DoTiling()
|
|
||||||
{
|
|
||||||
auto ret = GetShapeAttrsInfo();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
ret = GetPlatformInfo();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
if (!IsCapable()) {
|
|
||||||
return ge::GRAPH_PARAM_INVALID;
|
|
||||||
}
|
|
||||||
ret = DoOpTiling();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
ret = DoLibApiTiling();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
ret = GetWorkspaceSize();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
ret = PostTiling();
|
|
||||||
if (ret != ge::GRAPH_SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
context_->SetTilingKey(GetTilingKey());
|
|
||||||
DumpTilingInfo();
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新 context
|
|
||||||
virtual void Reset(gert::TilingContext* context)
|
|
||||||
{
|
|
||||||
context_ = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
virtual bool IsCapable() = 0;
|
|
||||||
// 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小
|
|
||||||
virtual ge::graphStatus GetPlatformInfo() = 0;
|
|
||||||
// 2、获取INPUT/OUTPUT/ATTR信息
|
|
||||||
virtual ge::graphStatus GetShapeAttrsInfo() = 0;
|
|
||||||
// 3、计算数据切分TilingData
|
|
||||||
virtual ge::graphStatus DoOpTiling() = 0;
|
|
||||||
// 4、计算高阶API的TilingData
|
|
||||||
virtual ge::graphStatus DoLibApiTiling() = 0;
|
|
||||||
// 5、计算TilingKey
|
|
||||||
[[nodiscard]] virtual uint64_t GetTilingKey() const = 0;
|
|
||||||
// 6、计算Workspace 大小
|
|
||||||
virtual ge::graphStatus GetWorkspaceSize() = 0;
|
|
||||||
// 7、保存Tiling数据
|
|
||||||
virtual ge::graphStatus PostTiling() = 0;
|
|
||||||
// 8、Dump Tiling数据
|
|
||||||
virtual void DumpTilingInfo()
|
|
||||||
{
|
|
||||||
int32_t enable = CheckLogLevel(static_cast<int32_t>(OP), DLOG_DEBUG);
|
|
||||||
if (enable != 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto buf = (uint32_t*)context_->GetRawTilingData()->GetData();
|
|
||||||
auto bufLen = context_->GetRawTilingData()->GetDataSize();
|
|
||||||
std::ostringstream oss;
|
|
||||||
oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen
|
|
||||||
<< ", content:";
|
|
||||||
for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) {
|
|
||||||
oss << *(buf + i) << ",";
|
|
||||||
if (oss.str().length() > 640) { // Split according to 640 to avoid truncation
|
|
||||||
OP_LOGD(context_, "%s", oss.str().c_str());
|
|
||||||
oss.str("");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
OP_LOGD(context_, "%s", oss.str().c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum)
|
|
||||||
{
|
|
||||||
uint32_t ration;
|
|
||||||
if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) {
|
|
||||||
return sliceNum;
|
|
||||||
}
|
|
||||||
ration = aivCoreNum / aicCoreNum;
|
|
||||||
return (sliceNum + (ration - 1)) / ration;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
[[nodiscard]] std::string GetShapeDebugStr(const T& shape) const
|
|
||||||
{
|
|
||||||
std::ostringstream oss;
|
|
||||||
oss << "[";
|
|
||||||
if (shape.GetDimNum() > 0) {
|
|
||||||
for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) {
|
|
||||||
oss << shape.GetDim(i) << ", ";
|
|
||||||
}
|
|
||||||
oss << shape.GetDim(shape.GetDimNum() - 1);
|
|
||||||
}
|
|
||||||
oss << "]";
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] std::string GetTensorDebugStr(
|
|
||||||
const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor)
|
|
||||||
{
|
|
||||||
if (shape == nullptr || tensor == nullptr) {
|
|
||||||
return "nil ";
|
|
||||||
}
|
|
||||||
std::ostringstream oss;
|
|
||||||
oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),";
|
|
||||||
oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),";
|
|
||||||
oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),";
|
|
||||||
oss << "(format: "
|
|
||||||
<< ge::TypeUtils::FormatToSerialString(
|
|
||||||
static_cast<ge::Format>(ge::GetPrimaryFormat(tensor->GetStorageFormat())))
|
|
||||||
<< "),";
|
|
||||||
oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") ";
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] std::string GetTilingContextDebugStr()
|
|
||||||
{
|
|
||||||
std::ostringstream oss;
|
|
||||||
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) {
|
|
||||||
oss << "input" << i << ": ";
|
|
||||||
oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) {
|
|
||||||
oss << "output" << i << ": ";
|
|
||||||
oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i));
|
|
||||||
}
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
[[nodiscard]] std::string GetTilingDataDebugStr() const
|
|
||||||
{
|
|
||||||
auto rawTilingData = context_->GetRawTilingData();
|
|
||||||
auto rawTilingDataSize = rawTilingData->GetDataSize();
|
|
||||||
auto data = reinterpret_cast<const int32_t*>(rawTilingData->GetData());
|
|
||||||
size_t len = rawTilingDataSize / sizeof(int32_t);
|
|
||||||
std::ostringstream oss;
|
|
||||||
for (size_t i = 0; i < len; i++) {
|
|
||||||
oss << data[i] << ", ";
|
|
||||||
}
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
gert::TilingContext* context_ = nullptr;
|
|
||||||
std::unique_ptr<platform_ascendc::PlatformAscendC> ascendcPlatform_{nullptr};
|
|
||||||
uint32_t blockDim_{0};
|
|
||||||
uint64_t workspaceSize_{0};
|
|
||||||
uint64_t tilingKey_{0};
|
|
||||||
AiCoreParams aicoreParams_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace OpTiling
|
|
||||||
} // namespace Transformer
|
|
||||||
} // namespace Ops
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file tiling_key.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
namespace Ops {
|
|
||||||
namespace Transformer {
|
|
||||||
namespace OpTiling {
|
|
||||||
constexpr uint64_t RecursiveSum()
|
|
||||||
{
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr uint64_t kBase = 10; // 10进制进位基数
|
|
||||||
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
|
||||||
{
|
|
||||||
return static_cast<uint64_t>(templateId) + kBase * RecursiveSum(templateIds...);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TilingKey 的生成规则:
|
|
||||||
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1,
|
|
||||||
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
|
|
||||||
// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分,
|
|
||||||
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
|
|
||||||
// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位;
|
|
||||||
// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位
|
|
||||||
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位
|
|
||||||
// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位
|
|
||||||
// 其余特化场景,定义自己的位域和值
|
|
||||||
// usage: get tilingKey from inputed types
|
|
||||||
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
|
||||||
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
|
||||||
|
|
||||||
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
|
||||||
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
|
||||||
{
|
|
||||||
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
|
||||||
}
|
|
||||||
|
|
||||||
// usage: get tilingKey from inputed types
|
|
||||||
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
|
||||||
|
|
||||||
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
|
||||||
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
|
||||||
SparseEnum::sparse))
|
|
||||||
|
|
||||||
} // namespace Optiling
|
|
||||||
} // namespace Transformer
|
|
||||||
} // namespace Ops
|
|
||||||
@@ -1,351 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file tiling_templates_registry.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "exe_graph/runtime/tiling_context.h"
|
|
||||||
#include "tiling_base.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
|
|
||||||
namespace Ops {
|
|
||||||
namespace Transformer {
|
|
||||||
namespace OpTiling {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::unique_ptr<TilingBaseClass> TILING_CLASS(gert::TilingContext* context)
|
|
||||||
{
|
|
||||||
return std::unique_ptr<T>(new (std::nothrow) T(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
using TilingClassCase = std::unique_ptr<TilingBaseClass> (*)(gert::TilingContext*);
|
|
||||||
|
|
||||||
class TilingCases {
|
|
||||||
public:
|
|
||||||
explicit TilingCases(std::string op_type) : op_type_(std::move(op_type))
|
|
||||||
{}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void AddTiling(int32_t priority)
|
|
||||||
{
|
|
||||||
OP_CHECK_IF(
|
|
||||||
cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return);
|
|
||||||
cases_[priority] = TILING_CLASS<T>;
|
|
||||||
OP_CHECK_IF(
|
|
||||||
cases_[priority] == nullptr,
|
|
||||||
OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return);
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::map<int32_t, TilingClassCase>& GetTilingCases()
|
|
||||||
{
|
|
||||||
return cases_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::map<int32_t, TilingClassCase> cases_;
|
|
||||||
const std::string op_type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// --------------------------------Interfacce with soc version --------------------------------
|
|
||||||
class TilingRegistryNew {
|
|
||||||
public:
|
|
||||||
TilingRegistryNew() = default;
|
|
||||||
|
|
||||||
#ifdef ASCENDC_OP_TEST
|
|
||||||
static TilingRegistryNew& GetInstance();
|
|
||||||
#else
|
|
||||||
static TilingRegistryNew& GetInstance()
|
|
||||||
{
|
|
||||||
static TilingRegistryNew registry_impl_;
|
|
||||||
return registry_impl_;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type, int32_t soc_version)
|
|
||||||
{
|
|
||||||
auto soc_iter = registry_map_.find(soc_version);
|
|
||||||
if (soc_iter == registry_map_.end()) {
|
|
||||||
std::map<std::string, std::shared_ptr<TilingCases>> op_type_map;
|
|
||||||
op_type_map[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
|
||||||
registry_map_[soc_version] = op_type_map;
|
|
||||||
} else {
|
|
||||||
if (soc_iter->second.find(op_type) == soc_iter->second.end()) {
|
|
||||||
soc_iter->second[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OP_CHECK_IF(
|
|
||||||
registry_map_[soc_version][op_type] == nullptr,
|
|
||||||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
|
||||||
return registry_map_[soc_version][op_type];
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
|
||||||
{
|
|
||||||
int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION;
|
|
||||||
const char* op_type = context->GetNodeType();
|
|
||||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
|
||||||
if (platformInfoPtr == nullptr) {
|
|
||||||
auto compileInfoPtr = static_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
|
||||||
OP_CHECK_IF(
|
|
||||||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
|
||||||
soc_version = compileInfoPtr->socVersion;
|
|
||||||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
|
||||||
} else {
|
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
|
||||||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
|
||||||
OP_LOGD(context, "soc version is %d", soc_version);
|
|
||||||
if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) {
|
|
||||||
OP_LOGE(op_type, "Do op tiling failed, cannot find soc version.");
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
|
||||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
|
||||||
auto tilingTemplate = it->second(context);
|
|
||||||
if (tilingTemplate != nullptr) {
|
|
||||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
|
||||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
|
||||||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
|
||||||
{
|
|
||||||
int32_t soc_version;
|
|
||||||
const char* op_type = context->GetNodeType();
|
|
||||||
auto platformInfoPtr = context->GetPlatformInfo();
|
|
||||||
if (platformInfoPtr == nullptr) {
|
|
||||||
auto compileInfoPtr = reinterpret_cast<const CompileInfoCommon*>(context->GetCompileInfo());
|
|
||||||
OP_CHECK_IF(
|
|
||||||
compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED);
|
|
||||||
soc_version = compileInfoPtr->socVersion;
|
|
||||||
OP_LOGD(context, "soc version in compileInfo is %d", soc_version);
|
|
||||||
} else {
|
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
|
||||||
soc_version = static_cast<int32_t>(ascendcPlatform.GetSocVersion());
|
|
||||||
OP_LOGD(context, "soc version is %d", soc_version);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version);
|
|
||||||
for (auto priority_id : priorities) {
|
|
||||||
auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id);
|
|
||||||
if (tilingCaseIter != tilingTemplateRegistryMap.end()) {
|
|
||||||
auto templateFunc = tilingCaseIter->second(context);
|
|
||||||
if (templateFunc != nullptr) {
|
|
||||||
ge::graphStatus status = templateFunc->DoTiling();
|
|
||||||
if (status == ge::GRAPH_SUCCESS) {
|
|
||||||
OP_LOGD(context, "Do general op tiling success priority=%d", priority_id);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type, int32_t soc_version)
|
|
||||||
{
|
|
||||||
auto soc_iter = registry_map_.find(soc_version);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
soc_iter == registry_map_.end(),
|
|
||||||
OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version),
|
|
||||||
return empty_tiling_case_);
|
|
||||||
auto op_iter = soc_iter->second.find(op_type);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."),
|
|
||||||
return empty_tiling_case_);
|
|
||||||
return op_iter->second->GetTilingCases();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::map<int32_t, std::map<std::string, std::shared_ptr<TilingCases>>> registry_map_; // key is socversion
|
|
||||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_{};
|
|
||||||
};
|
|
||||||
|
|
||||||
class RegisterNew {
|
|
||||||
public:
|
|
||||||
explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type))
|
|
||||||
{}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
RegisterNew& tiling(int32_t priority, int32_t soc_version)
|
|
||||||
{
|
|
||||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
|
||||||
tilingCases->AddTiling<T>(priority);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
RegisterNew& tiling(int32_t priority, const std::vector<int32_t>& soc_versions)
|
|
||||||
{
|
|
||||||
for (int32_t soc_version : soc_versions) {
|
|
||||||
auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."),
|
|
||||||
return *this);
|
|
||||||
tilingCases->AddTiling<T>(priority);
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const std::string op_type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// --------------------------------Interfacce without soc version --------------------------------
|
|
||||||
class TilingRegistry {
|
|
||||||
public:
|
|
||||||
TilingRegistry() = default;
|
|
||||||
|
|
||||||
#ifdef ASCENDC_OP_TEST
|
|
||||||
static TilingRegistry& GetInstance();
|
|
||||||
#else
|
|
||||||
static TilingRegistry& GetInstance()
|
|
||||||
{
|
|
||||||
static TilingRegistry registry_impl_;
|
|
||||||
return registry_impl_;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::shared_ptr<TilingCases> RegisterOp(const std::string& op_type)
|
|
||||||
{
|
|
||||||
if (registry_map_.find(op_type) == registry_map_.end()) {
|
|
||||||
registry_map_[op_type] = std::shared_ptr<TilingCases>(new (std::nothrow) TilingCases(op_type));
|
|
||||||
}
|
|
||||||
OP_CHECK_IF(
|
|
||||||
registry_map_[op_type] == nullptr,
|
|
||||||
OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr);
|
|
||||||
return registry_map_[op_type];
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context)
|
|
||||||
{
|
|
||||||
const char* op_type = context->GetNodeType();
|
|
||||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
|
||||||
for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) {
|
|
||||||
auto tilingTemplate = it->second(context);
|
|
||||||
if (tilingTemplate != nullptr) {
|
|
||||||
ge::graphStatus status = tilingTemplate->DoTiling();
|
|
||||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
|
||||||
OP_LOGD(context, "Do general op tiling success priority=%d", it->first);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
OP_LOGD(context, "Ignore general op tiling priority=%d", it->first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector<int32_t>& priorities)
|
|
||||||
{
|
|
||||||
const char* op_type = context->GetNodeType();
|
|
||||||
auto tilingTemplateRegistryMap = GetTilingTemplates(op_type);
|
|
||||||
for (auto priorityId : priorities) {
|
|
||||||
auto templateFunc = tilingTemplateRegistryMap[priorityId](context);
|
|
||||||
if (templateFunc != nullptr) {
|
|
||||||
ge::graphStatus status = templateFunc->DoTiling();
|
|
||||||
if (status == ge::GRAPH_SUCCESS) {
|
|
||||||
OP_LOGD(context, "Do general op tiling success priority=%d", priorityId);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (status != ge::GRAPH_PARAM_INVALID) {
|
|
||||||
OP_LOGD(context, "Do op tiling failed");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
OP_LOGE(op_type, "Do op tiling failed, no valid template is found.");
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::map<int32_t, TilingClassCase>& GetTilingTemplates(const std::string& op_type)
|
|
||||||
{
|
|
||||||
OP_CHECK_IF(
|
|
||||||
registry_map_.find(op_type) == registry_map_.end(),
|
|
||||||
OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_);
|
|
||||||
return registry_map_[op_type]->GetTilingCases();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::map<std::string, std::shared_ptr<TilingCases>> registry_map_;
|
|
||||||
const std::map<int32_t, TilingClassCase> empty_tiling_case_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Register {
|
|
||||||
public:
|
|
||||||
explicit Register(std::string op_type) : op_type_(std::move(op_type))
|
|
||||||
{}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
Register& tiling(int32_t priority)
|
|
||||||
{
|
|
||||||
auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_);
|
|
||||||
OP_CHECK_IF(
|
|
||||||
tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this);
|
|
||||||
tilingCases->AddTiling<T>(priority);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const std::string op_type_;
|
|
||||||
};
|
|
||||||
} // namespace OpTiling
|
|
||||||
} // namespace Transformer
|
|
||||||
} // namespace Ops
|
|
||||||
|
|
||||||
// op_type: 算子名称, class_name: 注册的 tiling 类, soc_version:芯片版本号
|
|
||||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
|
|
||||||
#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \
|
|
||||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
|
||||||
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
|
||||||
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_versions)
|
|
||||||
|
|
||||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
|
||||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
|
|
||||||
#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \
|
|
||||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
|
||||||
static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \
|
|
||||||
Ops::Transformer::OpTiling::Register(op_type).tiling<class_name>(priority)
|
|
||||||
|
|
||||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
|
||||||
// soc_version: soc版本,用于区分不同的soc
|
|
||||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类
|
|
||||||
#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \
|
|
||||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
|
||||||
static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \
|
|
||||||
Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling<class_name>(priority, soc_version)
|
|
||||||
|
|
||||||
// op_type: 算子名称, class_name: 注册的 tiling 类,
|
|
||||||
// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大
|
|
||||||
// 取代 REGISTER_TILING_TEMPLATE , 传入的op_type如果是字符串常量,需要去掉引号
|
|
||||||
#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \
|
|
||||||
[[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \
|
|
||||||
static Ops::Transformer::OpTiling::Register \
|
|
||||||
__attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \
|
|
||||||
Ops::Transformer::OpTiling::Register(#op_type).tiling<class_name>(priority)
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file tiling_type.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
enum class AxisEnum {
|
|
||||||
B = 0,
|
|
||||||
N2 = 1,
|
|
||||||
G = 2,
|
|
||||||
S1 = 3,
|
|
||||||
S2 = 4,
|
|
||||||
D = 5,
|
|
||||||
NONE = 9,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class DtypeEnum {
|
|
||||||
FLOAT16 = 0,
|
|
||||||
FLOAT32 = 1,
|
|
||||||
BFLOAT16 = 2,
|
|
||||||
FLOAT16_PRECISION = 3,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class PerformanceOrientedEnum {
|
|
||||||
BIG_BUFFER = 1,
|
|
||||||
BIG_DOUBLE_BUFFER = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class MatmulConfig {
|
|
||||||
NULL_CONFIG = 0,
|
|
||||||
NORMAL_CONFIG = 1,
|
|
||||||
MDL_CONFIG = 2
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class PseConfig {
|
|
||||||
NO_PSE = 0,
|
|
||||||
EXIST_PSE = 1
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class AttenMaskConfig {
|
|
||||||
NO_ATTEN_MASK = 0,
|
|
||||||
EXIST_ATTEN_MASK = 1
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class DropOutConfig {
|
|
||||||
NO_DROP_OUT = 0,
|
|
||||||
EXIST_DROP_OUT = 1
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class CubeFormatEnum {
|
|
||||||
ND = 0,
|
|
||||||
NZ = 1
|
|
||||||
};
|
|
||||||
enum class LayoutEnum {
|
|
||||||
BSND = 0,
|
|
||||||
SBND = 1,
|
|
||||||
BNSD = 2,
|
|
||||||
TND = 3,
|
|
||||||
NTD_TND = 4
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class CubeInputSourceEnum {
|
|
||||||
GM = 0,
|
|
||||||
L1 = 1
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class OptionEnum {
|
|
||||||
DISABLE = 0,
|
|
||||||
ENABLE = 1
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class SparseEnum {
|
|
||||||
ALL = 0,
|
|
||||||
NONE = 1,
|
|
||||||
ANY = 2,
|
|
||||||
CAUSAL = 3,
|
|
||||||
BAND = 4,
|
|
||||||
PREFIX = 5,
|
|
||||||
BAND_COMPRESS = 6,
|
|
||||||
RIGHT_DOWN_CAUSAL = 7,
|
|
||||||
RIGHT_DOWN_CAUSAL_BAND = 8,
|
|
||||||
BAND_LEFT_UP_CAUSAL = 9
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr uint64_t RecursiveSum()
|
|
||||||
{
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int64_t base10Multiplier = 10;
|
|
||||||
|
|
||||||
template <typename T, typename... Args> constexpr uint64_t RecursiveSum(T templateId, Args... templateIds)
|
|
||||||
{
|
|
||||||
return static_cast<uint64_t>(templateId) + base10Multiplier * RecursiveSum(templateIds...);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TilingKey 的生成规则:
|
|
||||||
// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1,
|
|
||||||
// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1:
|
|
||||||
// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分,
|
|
||||||
// 那么填AXIS_NONE。UB0和UB1各占一个十进制位;
|
|
||||||
// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位;
|
|
||||||
// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位
|
|
||||||
// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位
|
|
||||||
// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位
|
|
||||||
// 其余特化场景,定义自己的位域和值
|
|
||||||
// usage: get tilingKey from inputed types
|
|
||||||
// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
|
|
||||||
// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
|
|
||||||
|
|
||||||
constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
|
|
||||||
template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
|
|
||||||
{
|
|
||||||
return TILINGKEYOFFSET + RecursiveSum(templateIds...);
|
|
||||||
}
|
|
||||||
|
|
||||||
// usage: get tilingKey from inputed types
|
|
||||||
// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
|
|
||||||
|
|
||||||
#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
|
|
||||||
(GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
|
|
||||||
SparseEnum::sparse))
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file tiling_util.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
|
|
||||||
namespace Ops {
|
|
||||||
namespace Transformer {
|
|
||||||
namespace OpTiling {
|
|
||||||
bool IsRegbaseSocVersion(const gert::TilingParseContext* context);
|
|
||||||
|
|
||||||
bool IsRegbaseSocVersion(const gert::TilingContext* context);
|
|
||||||
|
|
||||||
const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);
|
|
||||||
} // namespace OpTiling
|
|
||||||
} // namespace Transformer
|
|
||||||
} // namespace Ops
|
|
||||||
@@ -1118,60 +1118,6 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons
|
|||||||
return combined_x;
|
return combined_x;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
|
|
||||||
const at::Tensor& x,
|
|
||||||
int64_t k,
|
|
||||||
int64_t kGroup,
|
|
||||||
int64_t groupCount,
|
|
||||||
int64_t groupSelectMode,
|
|
||||||
int64_t renorm,
|
|
||||||
int64_t normType,
|
|
||||||
bool outFlag,
|
|
||||||
double routedScalingFactor,
|
|
||||||
double eps,
|
|
||||||
const c10::optional<at::Tensor>& biasOptional
|
|
||||||
)
|
|
||||||
{
|
|
||||||
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
|
|
||||||
TORCH_CHECK(
|
|
||||||
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
|
|
||||||
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
|
|
||||||
x.scalar_type());
|
|
||||||
|
|
||||||
auto x_size = x.sizes();
|
|
||||||
auto rows = x_size[0];
|
|
||||||
auto expert_num = x_size[1];
|
|
||||||
const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); });
|
|
||||||
if (bias.defined()) {
|
|
||||||
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
|
|
||||||
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
|
|
||||||
auto bias_size = bias.sizes();
|
|
||||||
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
|
|
||||||
}
|
|
||||||
at::Tensor yOut = at::empty({rows, k}, x.options());
|
|
||||||
at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt));
|
|
||||||
at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
|
|
||||||
|
|
||||||
EXEC_NPU_CMD(aclnnMoeGatingTopK,
|
|
||||||
x, // input_x
|
|
||||||
biasOptional,
|
|
||||||
k, // k
|
|
||||||
kGroup, // k_group
|
|
||||||
groupCount, // group_count
|
|
||||||
groupSelectMode, // group_select_mode
|
|
||||||
renorm, // renorm
|
|
||||||
normType, // norm_type
|
|
||||||
outFlag, // out_flag
|
|
||||||
routedScalingFactor, // routed_scaling_factor
|
|
||||||
eps, // eps
|
|
||||||
yOut, // input_y (注意:这里应该是 yOut)
|
|
||||||
expertIdxOut, // output
|
|
||||||
outOut
|
|
||||||
);
|
|
||||||
|
|
||||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(outOut,expertIdxOut, yOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
|
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
|
||||||
const at::Tensor &x, const at::Tensor &expert_idx,
|
const at::Tensor &x, const at::Tensor &expert_idx,
|
||||||
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
|
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
|
||||||
@@ -1277,23 +1223,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
|||||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||||
{
|
{
|
||||||
// vLLM-Ascend custom ops
|
// vLLM-Ascend custom ops
|
||||||
ops.def(
|
|
||||||
"moe_gating_top_k(Tensor x, "
|
|
||||||
"int k, "
|
|
||||||
"int kGroup, "
|
|
||||||
"int groupCount, "
|
|
||||||
"int groupSelectMode, "
|
|
||||||
"int renorm, "
|
|
||||||
"int normType, "
|
|
||||||
"bool outFlag, "
|
|
||||||
"float routedScalingFactor, "
|
|
||||||
"float eps,"
|
|
||||||
"Tensor? biasOptional=None)"
|
|
||||||
|
|
||||||
"-> (Tensor outOut,Tensor expertIdxOut, Tensor yOut)"
|
|
||||||
);
|
|
||||||
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
|
|
||||||
//Moe_gating
|
|
||||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
|
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
|
||||||
|
|
||||||
|
|||||||
@@ -283,42 +283,6 @@ std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm_meta(
|
|||||||
return {output, add_out};
|
return {output, add_out};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
|
|
||||||
const at::Tensor& x,
|
|
||||||
int64_t k,
|
|
||||||
int64_t kGroup,
|
|
||||||
int64_t groupCount,
|
|
||||||
int64_t groupSelectMode,
|
|
||||||
int64_t renorm,
|
|
||||||
int64_t normType,
|
|
||||||
bool outFlag,
|
|
||||||
double routedScalingFactor,
|
|
||||||
double eps,
|
|
||||||
const c10::optional<at::Tensor>& biasOptional)
|
|
||||||
{
|
|
||||||
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
|
|
||||||
TORCH_CHECK(
|
|
||||||
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
|
|
||||||
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
|
|
||||||
x.scalar_type());
|
|
||||||
|
|
||||||
auto x_size = x.sizes();
|
|
||||||
auto rows = x_size[0];
|
|
||||||
auto expert_num = x_size[1];
|
|
||||||
const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); });
|
|
||||||
if (bias.defined()) {
|
|
||||||
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
|
|
||||||
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
|
|
||||||
auto bias_size = bias.sizes();
|
|
||||||
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
|
|
||||||
}
|
|
||||||
at::Tensor yOut = at::empty({rows, k}, x.options());
|
|
||||||
at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt));
|
|
||||||
at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
|
|
||||||
|
|
||||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(outOut,expertIdxOut, yOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom_meta(
|
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom_meta(
|
||||||
const at::Tensor &x, const at::Tensor &expert_idx,
|
const at::Tensor &x, const at::Tensor &expert_idx,
|
||||||
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
|
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
|
||||||
@@ -403,15 +367,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace meta
|
} // namespace meta
|
||||||
|
|
||||||
} // namespace vllm_ascend
|
} // namespace vllm_ascend
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||||
// the custom kernel been captured into aclgraph
|
// the custom kernel been captured into aclgraph
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||||
// Moe_gating_top_k
|
|
||||||
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
|
|
||||||
// Rotary embedding meta implementation
|
// Rotary embedding meta implementation
|
||||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||||
// Masked input and mask meta implementation
|
// Masked input and mask meta implementation
|
||||||
|
|||||||
@@ -179,8 +179,7 @@ class SmallOps(DecodeMoeOps):
|
|||||||
shared_expert_rank_num=self.shared_expert_rank_num,
|
shared_expert_rank_num=self.shared_expert_rank_num,
|
||||||
quant_mode=2,
|
quant_mode=2,
|
||||||
global_bs=self.batch_size * self.ep_world_size,
|
global_bs=self.batch_size * self.ep_world_size,
|
||||||
expert_token_nums_type=
|
expert_token_nums_type=1, # 0代表前缀和,1代表各自数量
|
||||||
1, # 0 represents prefix sum, 1 represents individual counts
|
|
||||||
)
|
)
|
||||||
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
|
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
|
||||||
output_dtype = x.dtype
|
output_dtype = x.dtype
|
||||||
@@ -189,8 +188,8 @@ class SmallOps(DecodeMoeOps):
|
|||||||
x=[expand_x],
|
x=[expand_x],
|
||||||
weight=[self.gmm1_weight],
|
weight=[self.gmm1_weight],
|
||||||
split_item=3,
|
split_item=3,
|
||||||
group_list_type=1, # Default is 0, represents prefix sum format
|
group_list_type=1, # 默认为0,代表前缀和形式
|
||||||
group_type=0, # 0 represents m-axis grouping
|
group_type=0, # 0代表m轴分组
|
||||||
group_list=expert_token_nums,
|
group_list=expert_token_nums,
|
||||||
output_dtype=torch.int32)[0]
|
output_dtype=torch.int32)[0]
|
||||||
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
|
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
@@ -366,7 +365,7 @@ def run_once(local_rank_id,
|
|||||||
with_mc2_mask=False):
|
with_mc2_mask=False):
|
||||||
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
|
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
|
||||||
) if output_to_file(local_rank_id) else None
|
) if output_to_file(local_rank_id) else None
|
||||||
global_rank_id = local_rank_id # Single machine
|
global_rank_id = local_rank_id # 单机
|
||||||
device_id = local_rank_id % 16
|
device_id = local_rank_id % 16
|
||||||
torch_npu.npu.set_device(device_id)
|
torch_npu.npu.set_device(device_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,322 +0,0 @@
|
|||||||
import itertools
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch_npu.testing.testcase import TestCase, run_tests
|
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm_ascend.utils import enable_custom_op
|
|
||||||
enable_custom_op()
|
|
||||||
except ImportError:
|
|
||||||
logging.warning(
|
|
||||||
"vllm_ascend.utils.enable_custom_op not found, skip custom op initialization"
|
|
||||||
)
|
|
||||||
|
|
||||||
def enable_custom_op() -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Set random seed for reproducibility
|
|
||||||
SEED = 45
|
|
||||||
random.seed(SEED)
|
|
||||||
np.random.seed(SEED)
|
|
||||||
torch.manual_seed(SEED)
|
|
||||||
if hasattr(torch, "npu") and torch.npu.is_available():
|
|
||||||
torch.npu.manual_seed_all(SEED)
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format="[%(asctime)s] %(levelname)s: %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def softmax_func(
|
|
||||||
x: np.ndarray,
|
|
||||||
axis: Optional[int] = None,
|
|
||||||
eps: float = 1e-20) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Stable softmax implementation for MOE gating.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: Input array
|
|
||||||
axis: Axis to compute softmax
|
|
||||||
eps: Epsilon to avoid division by zero
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
softmax_output: Softmax result
|
|
||||||
x_max: Max value for numerical stability
|
|
||||||
x_sum: Sum of exponentials
|
|
||||||
"""
|
|
||||||
if "float16" in x.dtype.name:
|
|
||||||
x = x.astype(np.float32)
|
|
||||||
|
|
||||||
x_max = x.max(axis=axis, keepdims=True)
|
|
||||||
x_sub = x - x_max
|
|
||||||
y = np.exp(x_sub)
|
|
||||||
x_sum = y.sum(axis=axis, keepdims=True)
|
|
||||||
softmax_output = y / (x_sum + eps)
|
|
||||||
|
|
||||||
return softmax_output, x_max, x_sum
|
|
||||||
|
|
||||||
|
|
||||||
class TestNpuMoeGatingTopK(TestCase):
|
|
||||||
"""Test suite for NPU MOE Gating Top-K operator compatibility."""
|
|
||||||
|
|
||||||
def moe_gating_top_k_np(
|
|
||||||
self,
|
|
||||||
x: np.ndarray,
|
|
||||||
k: int,
|
|
||||||
bias: Optional[np.ndarray] = None,
|
|
||||||
k_group: int = 1,
|
|
||||||
group_count: int = 1,
|
|
||||||
group_select_mode: int = 0,
|
|
||||||
renorm: int = 0,
|
|
||||||
norm_type: int = 0,
|
|
||||||
y2_flag: bool = False,
|
|
||||||
routed_scaling_factor: float = 1.0,
|
|
||||||
eps: float = 1e-20
|
|
||||||
) -> Tuple[torch.Tensor, np.ndarray, Optional[np.ndarray]]:
|
|
||||||
"""
|
|
||||||
NumPy reference implementation of MOE gating Top-K logic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: Input features, shape [batch_size, num_experts]
|
|
||||||
k: Number of experts to select per sample
|
|
||||||
bias: Gating bias, shape [num_experts]
|
|
||||||
k_group: Number of groups to select (group mode)
|
|
||||||
group_count: Number of expert groups
|
|
||||||
group_select_mode: 0 (max per group), 1 (sum of top-2 per group)
|
|
||||||
renorm: Whether to renormalize weights (1=enable, 0=disable)
|
|
||||||
norm_type: 0 (softmax), 1 (sigmoid)
|
|
||||||
y2_flag: Whether to return original x as y2
|
|
||||||
routed_scaling_factor: Weight scaling factor
|
|
||||||
eps: Epsilon for numerical stability
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
y: Selected expert weights (Tensor)
|
|
||||||
indices: Selected expert indices (int32 numpy array)
|
|
||||||
y2: Original x if y2_flag=True, else None
|
|
||||||
"""
|
|
||||||
# Convert torch tensors to numpy arrays if needed (compatibility layer)
|
|
||||||
if isinstance(x, torch.Tensor):
|
|
||||||
x = x.cpu().numpy()
|
|
||||||
if isinstance(bias, torch.Tensor):
|
|
||||||
bias = bias.cpu().numpy()
|
|
||||||
|
|
||||||
# Type conversion for numerical stability
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
if orig_dtype != np.float32:
|
|
||||||
x = x.astype(np.float32)
|
|
||||||
if bias is not None:
|
|
||||||
bias = bias.astype(np.float32)
|
|
||||||
|
|
||||||
# Apply normalization (softmax/sigmoid)
|
|
||||||
if norm_type == 0:
|
|
||||||
x, _, _ = softmax_func(x, axis=-1, eps=eps)
|
|
||||||
else:
|
|
||||||
x = 1 / (1 + np.exp(-x)) # Sigmoid
|
|
||||||
|
|
||||||
original_x = x.copy()
|
|
||||||
|
|
||||||
# Apply bias if provided
|
|
||||||
if bias is not None:
|
|
||||||
x = x + bias
|
|
||||||
|
|
||||||
# Group-based expert selection
|
|
||||||
if group_count > 1:
|
|
||||||
batch_size, num_experts = x.shape
|
|
||||||
if num_experts % group_count != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"num_experts ({num_experts}) must be divisible by group_count ({group_count})"
|
|
||||||
)
|
|
||||||
group_size = num_experts // group_count
|
|
||||||
|
|
||||||
# Reshape to [batch, groups, group_size]
|
|
||||||
x_reshaped = x.reshape(batch_size, group_count, group_size)
|
|
||||||
|
|
||||||
# Compute group scores
|
|
||||||
if group_select_mode == 0:
|
|
||||||
group_scores = np.amax(x_reshaped, axis=-1)
|
|
||||||
else:
|
|
||||||
# Sum of top-2 values per group
|
|
||||||
group_scores = np.partition(x_reshaped, -2,
|
|
||||||
axis=-1)[..., -2:].sum(axis=-1)
|
|
||||||
|
|
||||||
# Select top-k_group groups
|
|
||||||
top_groups = np.argsort(-group_scores, axis=-1,
|
|
||||||
kind="stable")[:, :k_group]
|
|
||||||
|
|
||||||
# Mask out non-selected groups with -inf
|
|
||||||
mask = np.ones((batch_size, group_count), dtype=bool)
|
|
||||||
mask[np.arange(batch_size)[:, None], top_groups] = False
|
|
||||||
x_reshaped = np.where(mask[..., None], float("-inf"), x_reshaped)
|
|
||||||
|
|
||||||
# Reshape back to original
|
|
||||||
x = x_reshaped.reshape(batch_size, num_experts)
|
|
||||||
|
|
||||||
# Select top-k experts
|
|
||||||
x_tensor = torch.from_numpy(x)
|
|
||||||
_, topk_indices = torch.sort(x_tensor,
|
|
||||||
dim=-1,
|
|
||||||
stable=True,
|
|
||||||
descending=True)
|
|
||||||
topk_indices = np.asarray(topk_indices[:, :k], dtype=np.int32)
|
|
||||||
|
|
||||||
# Extract weights for selected experts
|
|
||||||
selected_weights = np.take_along_axis(original_x, topk_indices, axis=1)
|
|
||||||
|
|
||||||
# Apply renormalization if needed
|
|
||||||
if norm_type == 1 or renorm == 1:
|
|
||||||
weight_sum = np.sum(selected_weights, axis=-1, keepdims=True)
|
|
||||||
selected_weights = selected_weights / (weight_sum + eps)
|
|
||||||
|
|
||||||
# Apply scaling factor
|
|
||||||
selected_weights *= routed_scaling_factor
|
|
||||||
|
|
||||||
# Prepare y2 output
|
|
||||||
y2 = original_x if y2_flag else None
|
|
||||||
|
|
||||||
# Convert back to torch tensor with original dtype
|
|
||||||
selected_weights_tensor = torch.tensor(selected_weights,
|
|
||||||
dtype=orig_dtype)
|
|
||||||
|
|
||||||
return selected_weights_tensor, topk_indices, y2
|
|
||||||
|
|
||||||
def test_npu_moe_gating_topk_multi(self) -> None:
|
|
||||||
"""
|
|
||||||
Multi-case test for NPU MOE Gating Top-K operator.
|
|
||||||
Validates compatibility with different input shapes and parameter combinations.
|
|
||||||
"""
|
|
||||||
# Test parameter space (aligned with vllm-ascend use cases)
|
|
||||||
test_configs = {
|
|
||||||
"group_select_modes": [0, 1],
|
|
||||||
"renorms": [1],
|
|
||||||
"norm_types": [0, 1],
|
|
||||||
"group_counts": [1, 8],
|
|
||||||
"k_ranges": [4, 8, 12, 16, 6, 32],
|
|
||||||
"x_dim0": range(1, 17), # Batch size 1-16
|
|
||||||
"x_dim1": [256, 128, 64, 208, 192, 160] # Expert counts
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate parameter combinations
|
|
||||||
param_combinations = itertools.product(
|
|
||||||
test_configs["group_select_modes"], test_configs["renorms"],
|
|
||||||
test_configs["norm_types"], test_configs["group_counts"],
|
|
||||||
test_configs["k_ranges"], test_configs["x_dim0"],
|
|
||||||
test_configs["x_dim1"])
|
|
||||||
|
|
||||||
# Limit test cases to avoid excessive runtime (adjust as needed)
|
|
||||||
max_test_cases = 100
|
|
||||||
tested_cases = 0
|
|
||||||
|
|
||||||
for params in param_combinations:
|
|
||||||
if tested_cases >= max_test_cases:
|
|
||||||
break
|
|
||||||
|
|
||||||
(group_select_mode, renorm, norm_type, group_count, k, dim0,
|
|
||||||
dim1) = params
|
|
||||||
|
|
||||||
# Skip invalid configurations
|
|
||||||
if group_count > 1:
|
|
||||||
if dim1 % group_count != 0:
|
|
||||||
continue
|
|
||||||
if k > (dim1 // group_count):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Generate random inputs (consistent with vllm-ascend input distribution)
|
|
||||||
x_np = np.random.uniform(-2.0, 2.0,
|
|
||||||
(dim0, dim1)).astype(np.float32)
|
|
||||||
bias_np = np.random.uniform(-2.0, 2.0, (dim1, )).astype(np.float32)
|
|
||||||
|
|
||||||
# Convert to torch tensors
|
|
||||||
x_tensor = torch.tensor(x_np, dtype=torch.float32)
|
|
||||||
bias_tensor = torch.tensor(bias_np, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Random k_group (within valid range)
|
|
||||||
k_group = random.randint(1, min(group_count, 4))
|
|
||||||
|
|
||||||
# Fixed parameters (aligned with NPU OP defaults)
|
|
||||||
y2_flag = False
|
|
||||||
routed_scaling_factor = 1.0
|
|
||||||
eps = 1e-20
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get NumPy reference result
|
|
||||||
ref_weights, ref_indices, ref_y2 = self.moe_gating_top_k_np(
|
|
||||||
x=x_tensor,
|
|
||||||
k=k,
|
|
||||||
bias=bias_tensor,
|
|
||||||
k_group=k_group,
|
|
||||||
group_count=group_count,
|
|
||||||
group_select_mode=group_select_mode,
|
|
||||||
renorm=renorm,
|
|
||||||
norm_type=norm_type,
|
|
||||||
y2_flag=y2_flag,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
eps=eps)
|
|
||||||
|
|
||||||
# Skip if NPU OP is not available
|
|
||||||
if not hasattr(torch.ops, "_C_ascend") or not hasattr(
|
|
||||||
torch.ops._C_ascend, "moe_gating_top_k"):
|
|
||||||
logger.warning(
|
|
||||||
"NPU MOE gating OP not found, skipping NPU test")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get NPU OP result
|
|
||||||
npu_weights, npu_indices, npu_y2 = torch.ops._C_ascend.moe_gating_top_k(
|
|
||||||
x=x_tensor.npu(),
|
|
||||||
k=k,
|
|
||||||
kGroup=k_group,
|
|
||||||
groupCount=group_count,
|
|
||||||
groupSelectMode=group_select_mode,
|
|
||||||
renorm=renorm,
|
|
||||||
normType=norm_type,
|
|
||||||
outFlag=y2_flag,
|
|
||||||
routedScalingFactor=routed_scaling_factor,
|
|
||||||
eps=eps,
|
|
||||||
biasOptional=bias_tensor.npu()
|
|
||||||
if bias_tensor is not None else None)
|
|
||||||
|
|
||||||
# Convert NPU results to CPU for comparison
|
|
||||||
npu_weights_cpu = npu_weights.cpu()
|
|
||||||
npu_indices_cpu = npu_indices.cpu().numpy()
|
|
||||||
|
|
||||||
# Log test case info (vllm-ascend standard format)
|
|
||||||
logger.info(
|
|
||||||
f"Test Case {tested_cases + 1}: "
|
|
||||||
f"x_shape=({dim0},{dim1}), k={k}, group_count={group_count}, "
|
|
||||||
f"select_mode={group_select_mode}, norm_type={norm_type}, renorm={renorm}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate results (RTOL=1e-3 is standard for NPU numerical tolerance)
|
|
||||||
self.assertRtolEqual(ref_weights,
|
|
||||||
npu_weights_cpu,
|
|
||||||
rtol=1e-3,
|
|
||||||
atol=1e-5)
|
|
||||||
self.assertRtolEqual(ref_indices, npu_indices_cpu)
|
|
||||||
|
|
||||||
# Validate y2 if enabled
|
|
||||||
if y2_flag:
|
|
||||||
self.assertRtolEqual(ref_y2,
|
|
||||||
npu_y2.cpu().numpy(),
|
|
||||||
rtol=1e-3,
|
|
||||||
atol=1e-5)
|
|
||||||
|
|
||||||
tested_cases += 1
|
|
||||||
logger.info(f"Test Case {tested_cases} passed ")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Test Case failed with error: {str(e)}",
|
|
||||||
exc_info=True)
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"Completed {tested_cases}/{max_test_cases} test cases")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Run tests with vllm-ascend standard verbosity
|
|
||||||
run_tests(verbosity=2)
|
|
||||||
@@ -311,7 +311,7 @@ def test_client_handler_mismatch(server_config):
|
|||||||
mismatch_data = {
|
mismatch_data = {
|
||||||
"label": "JOIN",
|
"label": "JOIN",
|
||||||
"content": {
|
"content": {
|
||||||
"device_id": 1, # Mismatched ID
|
"device_id": 1, # 不匹配的ID
|
||||||
"model_path": "/wrong/model",
|
"model_path": "/wrong/model",
|
||||||
"tp": 2,
|
"tp": 2,
|
||||||
"pp": 2,
|
"pp": 2,
|
||||||
|
|||||||
@@ -670,7 +670,7 @@ class TestNPUWorker(TestBase):
|
|||||||
(5000, 10000),
|
(5000, 10000),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create worker mock
|
# 创建 worker mock
|
||||||
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
||||||
worker = NPUWorker()
|
worker = NPUWorker()
|
||||||
worker.init_npu_memory = 8500
|
worker.init_npu_memory = 8500
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
from vllm_ascend.utils import get_weight_prefetch_method
|
from vllm_ascend.utils import get_weight_prefetch_method
|
||||||
|
|
||||||
@@ -213,19 +214,21 @@ def _select_experts_with_fusion_ops(
|
|||||||
e_score_correction_bias.dtype != router_logits.dtype:
|
e_score_correction_bias.dtype != router_logits.dtype:
|
||||||
e_score_correction_bias = e_score_correction_bias.to(
|
e_score_correction_bias = e_score_correction_bias.to(
|
||||||
router_logits.dtype)
|
router_logits.dtype)
|
||||||
_, topk_ids, topk_weights = torch.ops._C_ascend.moe_gating_top_k(
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=top_k,
|
k=top_k,
|
||||||
kGroup=topk_group,
|
bias=e_score_correction_bias,
|
||||||
groupCount=num_expert_group,
|
k_group=topk_group,
|
||||||
groupSelectMode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
group_count=num_expert_group,
|
||||||
renorm=1, # 0: softmax->topk(fix); 1: topk->softmax
|
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||||
normType=norm_type, # 0: softmax; 1: sigmoid
|
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||||
outFlag=False, # todo new api; should the third output be output
|
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||||
routedScalingFactor=1,
|
# out_flag=False, # todo new api; should the third output be output
|
||||||
eps=float(1e-20),
|
# y2_flag=False, # old api; should the third output be output
|
||||||
biasOptional=e_score_correction_bias,
|
routed_scaling_factor=1,
|
||||||
)
|
eps=float(1e-20))
|
||||||
|
if scoring_func == "softmax":
|
||||||
|
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user