From 3be8e33fe9a78529e6202cafcfaaffdede414441 Mon Sep 17 00:00:00 2001 From: ZCG12345 <2097562023@qq.com> Date: Wed, 7 Jan 2026 21:42:31 +0800 Subject: [PATCH] [Kernel] Add moe_gating_top_k operator support for Ascend NPU (#5579) ### What this PR does / why we need it? 1.replace moe_gating_top_k from torch_npu with custom op 2.enable the renorm function of moe_gating_top_k in softmax scenerio ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? No need test - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/7157596103666ee7ccb7008acee8bff8a8ff1731 --------- Signed-off-by: ZCG12345 <2097562023@qq.com> --- csrc/build_aclnn.sh | 3 +- csrc/moe_gating_top_k/op_host/CMakeLists.txt | 42 ++ csrc/moe_gating_top_k/op_host/error_log.h | 56 ++ csrc/moe_gating_top_k/op_host/math_util.h | 61 ++ .../op_host/moe_gating_top_k_def.cpp | 71 ++ .../op_host/moe_gating_top_k_infershape.cpp | 147 ++++ .../op_host/moe_gating_top_k_proto.cpp | 15 + .../op_host/moe_gating_top_k_proto.h | 66 ++ .../op_host/moe_gating_top_k_tiling.cpp | 573 +++++++++++++++ .../op_host/moe_gating_top_k_tiling.h | 86 +++ .../moe_gating_top_k_tiling_arch35.cpp | 521 ++++++++++++++ .../op_host/moe_gating_top_k_tiling_base.cpp | 38 + csrc/moe_gating_top_k/op_kernel/common.h | 89 +++ csrc/moe_gating_top_k/op_kernel/error_log.h | 55 ++ .../op_kernel/moe_gating_top_k.cpp | 63 ++ .../op_kernel/moe_gating_top_k_apt.cpp | 46 ++ .../op_kernel/moe_gating_top_k_e_k_fullload.h | 404 +++++++++++ .../op_kernel/moe_gating_top_k_generalized.h | 669 ++++++++++++++++++ .../moe_gating_top_k_without_group.h | 338 +++++++++ .../tiling_base/data_copy_transpose_tiling.h | 51 ++ .../data_copy_transpose_tiling_def.h | 43 ++ csrc/moe_gating_top_k/tiling_base/error_log.h | 56 ++ .../tiling_base/tiling_base.h | 256 +++++++ .../moe_gating_top_k/tiling_base/tiling_key.h | 63 ++ .../tiling_base/tiling_templates_registry.h | 351 +++++++++ .../tiling_base/tiling_type.h | 139 ++++ 
.../tiling_base/tiling_util.h | 30 + csrc/torch_binding.cpp | 73 ++ csrc/torch_binding_meta.cpp | 39 + .../ops/singlecard_ops/test_fused_moe.py | 4 +- .../test_npu_moe_gating_top_k.py | 210 ++++++ vllm_ascend/ops/fused_moe/experts_selector.py | 22 +- 32 files changed, 4667 insertions(+), 13 deletions(-) create mode 100644 csrc/moe_gating_top_k/op_host/CMakeLists.txt create mode 100644 csrc/moe_gating_top_k/op_host/error_log.h create mode 100644 csrc/moe_gating_top_k/op_host/math_util.h create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp create mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp create mode 100644 csrc/moe_gating_top_k/op_kernel/common.h create mode 100644 csrc/moe_gating_top_k/op_kernel/error_log.h create mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp create mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp create mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h create mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h create mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h create mode 100644 csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h create mode 100644 csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h create mode 100644 csrc/moe_gating_top_k/tiling_base/error_log.h create mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_base.h create mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_key.h 
create mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h create mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_type.h create mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_util.h create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/test_npu_moe_gating_top_k.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index f24a9f02..7eba981b 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -70,6 +70,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "dispatch_layout" "notify_dispatch" "moe_init_routing_custom" + "moe_gating_top_k" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/moe_gating_top_k/op_host/CMakeLists.txt b/csrc/moe_gating_top_k/op_host/CMakeLists.txt new file mode 100644 index 00000000..0a6f51db --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/CMakeLists.txt @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# This program is free software, you can redistribute it and/or modify. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ---------------------------------------------------------------------------- + +add_ops_compile_options( + OP_NAME MoeGatingTopK + OPTIONS --cce-auto-sync=on + -Wno-deprecated-declarations + -Werror +) + +# Host (aclnn) +if (BUILD_OPEN_PROJECT) + target_sources(op_host_aclnn PRIVATE + moe_gating_top_k_def.cpp + + ) + + # Tiling + target_sources(optiling PRIVATE + moe_gating_top_k_tiling.cpp + moe_gating_top_k_tiling_base.cpp + moe_gating_top_k_tiling_arch35.cpp + ) + + target_sources(opsproto PRIVATE + moe_gating_top_k_proto.cpp + moe_gating_top_k_infershape.cpp + + ) + + target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ) + +endif() diff --git a/csrc/moe_gating_top_k/op_host/error_log.h b/csrc/moe_gating_top_k/op_host/error_log.h new file mode 100644 index 00000000..724775a7 --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/error_log.h @@ -0,0 +1,56 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) 
\ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + + +#define OP_CHECK_IF(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) + + + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if ((ptr) == nullptr) { \ + OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/op_host/math_util.h b/csrc/moe_gating_top_k/op_host/math_util.h new file mode 100644 index 00000000..edc1c8ea --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/math_util.h @@ -0,0 +1,61 @@ +/** +* Copyright (c) 2025 Huawei Technologies Co., Ltd. +* This program is free software, you can redistribute it and/or modify it under the terms and conditions of +* CANN Open Software License Agreement Version 2.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. +*/ + +/*! 
+ * \file math_util.h + * \brief + */ + +#ifndef TILING_MATMUL_MATH_UTIL_H +#define TILING_MATMUL_MATH_UTIL_H + +#include +#include +#include +#include +namespace matmul_tiling { +class MathUtil { +public: + static bool IsEqual(float leftValue, float rightValue); + template + static auto CeilDivision(T num1, T num2) -> T + { + if (num2 == 0) { + return 0; + } + return static_cast((static_cast(num1) + static_cast(num2) - 1) / + static_cast(num2)); + } + template + static auto Align(T num1, T num2) -> T + { + return CeilDivision(num1, num2) * num2; + } + static int32_t AlignDown(int32_t num1, int32_t num2); + static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c); + static int32_t MapShape(int32_t shape, bool roundUpFlag = true); + static void AddFactor(std::vector &dimsFactors, int32_t dim); + static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart, + const int32_t factorEnd); + static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart, + const int32_t factorEnd); + static bool CheckFactorNumSatisfy(const int32_t dim); + static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum, + bool isKDim); + static void GetFactors(std::vector &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor); + static void GetFactors(std::vector &factorList, int32_t srcNum, int32_t maxFactor); + static void GetBlockFactors(std::vector &factorList, const int32_t oriShape, const int32_t mpShape, + const int32_t coreNum, const int32_t maxNum); + static int32_t GetNonFactorMap(std::vector &factorList, int32_t srcNum, int32_t maxFactor); + static std::vector> GetFactorPairs(int32_t num); + static std::pair DivideIntoMainAndTail(int32_t num, int32_t divisor); +}; +} // namespace matmul_tiling +#endif // _MATH_UTIL_H_ diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp new file mode 
100644 index 00000000..52ecf37d --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp @@ -0,0 +1,71 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_def.cpp + * \brief + */ +#include "register/op_def_registry.h" + +namespace ops { +class MoeGatingTopK : public OpDef { +public: + explicit MoeGatingTopK(const char *name) : OpDef(name) + { + this->Input("x") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("bias") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expert_idx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + 
this->Output("out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("k").Int(); + this->Attr("k_group").AttrType(OPTIONAL).Int(1); + this->Attr("group_count").AttrType(OPTIONAL).Int(1); + this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0); + this->Attr("renorm").AttrType(OPTIONAL).Int(0); + this->Attr("norm_type").AttrType(OPTIONAL).Int(0); + this->Attr("out_flag").AttrType(OPTIONAL).Bool(false); + this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0); + this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f); + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + + OpAICoreConfig regbaseCfg; + regbaseCfg.DynamicCompileStaticFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt"); + this->AICore().AddConfig("ascend910_95", regbaseCfg); + } +}; + +OP_ADD(MoeGatingTopK); +} // namespace ops \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp new file mode 100644 index 00000000..37b72d1c --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp @@ -0,0 +1,147 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_gating_top_k_infershape.cpp + * \brief + */ + +#include "exe_graph/runtime/infer_shape_context.h" +#include "register/op_impl_registry.h" +#include "error_log.h" + +#include + +#include +#define TO_STRING(x) std::string(#x) + +using namespace ge; +namespace ops { +static constexpr size_t DIM_ONE = 1; +static constexpr size_t DIM_TWO = 2; +static constexpr int64_t NEG_ONE = -1; +static constexpr int64_t X_INDEX = 0; +static constexpr int64_t BIAS_INDEX = 1; +static constexpr int64_t Y_INDEX = 0; +static constexpr int64_t EXPERT_IDX_INDEX = 1; +static constexpr int64_t OUT_INDEX = 2; + +static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape) +{ + int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0); + int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1); + if (XRows < NEG_ONE || expertNum < NEG_ONE) { + OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape, + const int64_t k) +{ + if (xShape->GetDimNum() == 1U) { + if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) { + OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.", + TO_STRING(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + } else if (xShape->GetDimNum() != DIM_TWO) { + OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.", + TO_STRING(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + + if (k < 0) { + OP_LOGE(context, "k must be a non-negative number."); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k) +{ + OP_LOGD(context, "x shape is: %s.", 
TO_STRING(*xShape).c_str()); + OP_LOGD(context, "k is: %ld.", k); +} + +static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape, + const gert::Shape *expertIdxShape, const gert::Shape *outShape) +{ + OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str()); + OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str()); + OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str()); +} + +static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context) +{ + OP_LOGD(context, "Begin to do MoeGatingTopKInfershape."); + + // 获取输入shape + const gert::Shape *xShape = context->GetInputShape(0); + OP_CHECK_NULL_WITH_CONTEXT(context, xShape); + gert::Shape *yShape = context->GetOutputShape(0); + OP_CHECK_NULL_WITH_CONTEXT(context, yShape); + gert::Shape *expertIdxShape = context->GetOutputShape(1); + OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape); + gert::Shape *outShape = context->GetOutputShape(2); + OP_CHECK_NULL_WITH_CONTEXT(context, outShape); + + // 获取attr + auto attrs = context->GetAttrs(); + OP_CHECK_NULL_WITH_CONTEXT(context, attrs); + const int64_t *kPtr = attrs->GetAttrPointer(0); + OP_CHECK_NULL_WITH_CONTEXT(context, kPtr); + const int64_t k = *kPtr; + ShowInputShapeInfo(context, xShape, k); + + // 参数校验 + if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0); + int64_t expertNum = xShape->GetDimNum() == 1U ? 
NEG_ONE : xShape->GetDim(1); + + yShape->SetDimNum(DIM_TWO); + yShape->SetDim(0U, rows); + yShape->SetDim(1U, k); + + expertIdxShape->SetDimNum(DIM_TWO); + expertIdxShape->SetDim(0U, rows); + expertIdxShape->SetDim(1U, k); + + outShape->SetDimNum(DIM_TWO); + outShape->SetDim(0U, rows); + outShape->SetDim(1U, expertNum); + + ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape); + OP_LOGD(context, "End to do MoeGatingTopKInfershape."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context) +{ + OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType."); + auto xDtype = context->GetInputDataType(0); + context->SetOutputDataType(Y_INDEX, xDtype); + context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32); + context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT); + OP_LOGD(context, "End to do MoeGatingTopKInferDataType."); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK); +} // namespace ops diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp new file mode 100644 index 00000000..f10adf71 --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp @@ -0,0 +1,15 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_proto.h + * \brief + */ +#include "moe_gating_top_k_proto.h" \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h new file mode 100644 index 00000000..068acdfc --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h @@ -0,0 +1,66 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_proto.h + * \brief + */ +#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ +#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ + +#include "graph/operator_reg.h" + +namespace ge { + +/** + * @brief Compute renorm(sigmoid) and topk for moe input. + * + * @par Inputs: + * @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count. + * @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x. + * + * @par Outputs: + * @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x. 
+ The size of the non-1 axis must be the same as that of the corresponding axis of x. + The size of the -1 axis must be the same as that of k. + * @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y. + * @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x. + * + * @par Attributes: + * @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value. + * @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1. + * @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1. + * @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0. + * @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0. + * @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0. + * @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false. + * @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0. + * @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20. 
+ */ +REG_OP(MoeGatingTopK) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) + .OUTPUT(expert_idx, TensorType({DT_INT32})) + .OUTPUT(out, TensorType({DT_FLOAT})) + .REQUIRED_ATTR(k, Int) + .ATTR(k_group, Int, 1) + .ATTR(group_count, Int, 1) + .ATTR(group_select_mode, Int, 0) + .ATTR(renorm, Int, 0) + .ATTR(norm_type, Int, 0) + .ATTR(out_flag, Bool, false) + .ATTR(routed_scaling_factor, Float, 1.0) + .ATTR(eps, Float, 1e-20f) + .OP_END_FACTORY_REG(MoeGatingTopK) + +} // namespace ge + +#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp new file mode 100644 index 00000000..1e17b616 --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp @@ -0,0 +1,573 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! 
+ * \file moe_gating_top_k_tiling.cpp + * \brief + */ +#include +#include "register/op_def_registry.h" +#include "exe_graph/runtime/infer_shape_context.h" +#include "register/op_impl_registry.h" +#include "../tiling_base/tiling_base.h" +#include "../tiling_base/tiling_templates_registry.h" +#include "platform/platform_info.h" + + + +#include "error_log.h" +#include "moe_gating_top_k_tiling.h" + + +#ifndef CEIL_ALIGN +#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align)) +#endif + +#ifndef CEIL_DIV +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) +#endif +namespace optiling { +const static int64_t GROUP_SELECT_MODE_MAX = 0; +const static int64_t GROUP_SELECT_MODE_SUM = 1; +const static int64_t RENORM_NO = 0; +const static int64_t RENORM_L1 = 1; +const static int64_t NORM_TYPE_SOFTMAX = 0; +const static int64_t NORM_TYPE_SIGMOID = 1; +const static int64_t OUT_FLAG_FALSE = 0; +const static int64_t OUT_FLAG_TRUE = 1; +const static size_t X_INPUT_DIMS = 2; +const static size_t BIAS_INPUT_DIMS = 1; +const static size_t Y_OUTPUT_DIMS = 2; +const static size_t EXPERT_IDX_OUTPUY_DIMS = 2; +const static size_t OUT_OUTPUT_DIMS = 2; +const static int64_t MAX_EXPERT_COUNT = 2048; + +const static int64_t X_INPUT_INDEX = 0; +const static int64_t BIAS_INPUT_INDEX = 1; +const static int64_t Y_OUTPUT_INDEX = 0; +const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1; +const static int64_t OUT_OUTPUT_INDEX = 2; +const static int64_t K_ATTR_INDEX = 0; +const static int64_t K_GROUP_ATTR_INDEX = 1; +const static int64_t GROUP_COUNT_ATTR_INDEX = 2; +const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3; +const static int64_t RENORM_ATTR_INDEX = 4; +const static int64_t NORM_TYPE_ATTR_INDEX = 5; +const static int64_t OUT_FLAG_ATTR_INDEX = 6; +const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7; +const static int64_t EPS_ATTR_INDEX = 8; +const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216; +const static uint32_t DATATYPESIZE_FLOAT = 4; +const static bool IS_LARGEST 
= true; +const static bool IS_INITINDEX = false; +const static bool IS_REUSESOURCE = false; +const static uint64_t WITH_GROUP_CONDITION = 1; +const static uint64_t WITHOUT_GROUP_CONDITION = 2; +const static uint64_t MAX_IN_GROUP_CONDITION = 3; +constexpr int32_t ROW_COUNT_PER_TASK = 1; + +const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0; +const static uint64_t TILING_KEY_WITHOUT_GROUP = 1; +const static uint64_t TILING_KEY_GENERALIZED = 2; + +inline static int64_t CeilLog4(int64_t x) +{ + return static_cast(std::ceil(std::log(x) / std::log(4))); // 4 for four +} + +class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass { +public: + explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context) + { + Reset(); + } + ~MoeGatingTopKTilingBase() override = default; + + void Reset(gert::TilingContext *context) override + { + TilingBaseClass::Reset(context); + Reset(); + } + +protected: + bool IsCapable() override + { + return true; + } + + ge::graphStatus GetPlatformInfo() override; + + ge::graphStatus GetShapeAttrsInfo() override; + + ge::graphStatus DoOpTiling() override; + + ge::graphStatus DoLibApiTiling() override; + + uint64_t GetTilingKey() const override; + + ge::graphStatus GetWorkspaceSize() override; + + ge::graphStatus PostTiling() override; + void Reset(); + +private: + ge::graphStatus CheckInputShape(); + ge::graphStatus CheckAttr(); + ge::graphStatus CheckOutShape(); + void SplitRows(); + void CalTmpBufUbSize(); + + const gert::Shape *xShape_ = nullptr; + const gert::Shape *biasShape_ = nullptr; + const gert::Shape *yShape_ = nullptr; + const gert::Shape *expertIdxShape_ = nullptr; + const gert::Shape *outShape_ = nullptr; + + int64_t rows_ = 0; + int64_t expertCount_ = 0; + int64_t addBias_ = 0; + + int64_t k_ = 0; + int64_t kGroup_ = 0; + int64_t groupCount_ = 0; + int64_t perGroupExpertCount_ = 0; + int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX; 
+ int64_t renorm_ = RENORM_NO; + int64_t normType_ = NORM_TYPE_SOFTMAX; + int64_t outFlag_ = OUT_FLAG_FALSE; + float routedScalingFactor_ = 1.0; + float eps_ = 1e-20f; + + int64_t inputDtypeSize_; + const char *opName_ = ""; + MoeGatingTopKTilingData moeGatingTopKTilingData_; +}; + +ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape() +{ + size_t xDimNum = xShape_->GetDimNum(); + + OP_CHECK_IF(xDimNum != X_INPUT_DIMS, + + OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS), + return ge::GRAPH_FAILED); + + + rows_ = xShape_->GetDim(0); + expertCount_ = xShape_->GetDim(1); + + moeGatingTopKTilingData_.set_rowCount(rows_); + moeGatingTopKTilingData_.set_expertCount(expertCount_); + if (biasShape_ != nullptr) { + addBias_ = 1; + size_t biasDimNum = biasShape_->GetDimNum(); + OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS, + OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS), + return ge::GRAPH_FAILED); + OP_CHECK_IF( + biasShape_->GetDim(0) != expertCount_, + OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_), + return ge::GRAPH_FAILED); + + } + moeGatingTopKTilingData_.set_addBias(addBias_); + + OP_CHECK_IF(k_ > expertCount_, + OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::CheckAttr() +{ + OP_CHECK_IF( + expertCount_ > MAX_EXPERT_COUNT, + OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED); + + OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(kGroup_ > groupCount_, 
+        OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_),
+        return ge::GRAPH_FAILED);
+
+    OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_),
+        return ge::GRAPH_FAILED);
+
+    OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID,
+        OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_,
+            NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID),
+        return ge::GRAPH_FAILED);
+
+    OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX,
+        OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_,
+            GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX),
+        return ge::GRAPH_FAILED);
+
+    // BUGFIX: the condition accepts both RENORM_NO and RENORM_L1, so the error text
+    // must list both supported values instead of only RENORM_NO.
+    OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1,
+        OP_LOGE(context_, "renorm is: %ld, but currently only support %ld and %ld.", renorm_, RENORM_NO, RENORM_L1),
+        return ge::GRAPH_FAILED);
+
+    // BUGFIX: the second format argument is groupCount_, but the message previously
+    // labeled it "k_group"; label it group_count so the log matches the value printed.
+    OP_CHECK_IF(expertCount_ % groupCount_ != 0,
+        OP_LOGE(context_, "Expert count : %ld is not divisible by group_count: %ld", expertCount_, groupCount_),
+        return ge::GRAPH_FAILED);
+
+    perGroupExpertCount_ = expertCount_ / groupCount_;
+
+    OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_);
+
+    // BUGFIX: the check is perGroupExpertCount_ < 1 (i.e. it must be > 0); the old
+    // message claimed "greater than 1", which contradicts the condition.
+    OP_CHECK_IF(perGroupExpertCount_ < 1,
+        OP_LOGE(context_, "group expert count is: %ld, but should be greater than 0.", perGroupExpertCount_),
+        return ge::GRAPH_FAILED);
+    OP_CHECK_IF(
+        groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2,
+        OP_LOGE(context_,
+            "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.",
+            perGroupExpertCount_, groupSelectMode_),
+        return ge::GRAPH_FAILED);
+    OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_,
+        OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_),
+        return ge::GRAPH_FAILED);
+    int64_t groupExpertCountAlign =
CEIL_ALIGN(perGroupExpertCount_, 32L); + OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign); + if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) { + + OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT, + OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.", + groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT), + return ge::GRAPH_FAILED); + } + + moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_); + moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo() +{ + opName_ = context_->GetNodeName(); + OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_); + auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr); + xShape_ = &xShapePtr->GetStorageShape(); + OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str()); + + auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX); + biasShape_ = biasShapePtr == nullptr ? 
nullptr : &biasShapePtr->GetStorageShape(); + if (biasShape_ != nullptr) { + OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str()); + } + + auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr); + yShape_ = &yShapePtr->GetStorageShape(); + OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str()); + auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr); + expertIdxShape_ = &expertIdxPtr->GetStorageShape(); + OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str()); + auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr); + outShape_ = &outPtr->GetStorageShape(); + if (outShape_ != nullptr) { + OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str()); + } + + auto x = context_->GetInputDesc(X_INPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, x); + auto xDtype = x->GetDataType(); + OP_CHECK_IF( + (xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16), + OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. 
please check.", + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + + if (biasShapePtr != nullptr) { + auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType(); + OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str()); + OP_CHECK_IF((biasDtype != xDtype), + OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.", + ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(), + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + } + + auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc); + auto yDtype = yDesc->GetDataType(); + OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str()); + OP_CHECK_IF((yDtype != xDtype), + OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.", + ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(), + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + + auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc); + auto expertIdDtype = expertIdDesc->GetDataType(); + OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()); + OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32), + OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.", + ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()), + return ge::GRAPH_FAILED); + + auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc); + auto normOutDtype = normOutDesc->GetDataType(); + OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT), + OP_LOGE(context_, "norm out dtype %s error, only supports float. 
please check.", + ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()), + return ge::GRAPH_FAILED); + + + auto attrs = context_->GetAttrs(); + OP_CHECK_NULL_WITH_CONTEXT(context_, attrs); + + const int64_t *kPtr = attrs->GetAttrPointer(K_ATTR_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr); + k_ = *kPtr; + OP_LOGI(context_, "Attr k is: %ld", k_); + moeGatingTopKTilingData_.set_k(k_); + + OP_LOGI(context_, "Attr k is: %ld ", k_); + + const int64_t *kGroupPtr = attrs->GetAttrPointer(K_GROUP_ATTR_INDEX); + if (kGroupPtr != nullptr) { + kGroup_ = *kGroupPtr; + OP_LOGI(context_, "Attr k_group is: %ld", kGroup_); + moeGatingTopKTilingData_.set_kGroup(kGroup_); + } + OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_); + + const int64_t *groupCountPtr = attrs->GetAttrPointer(GROUP_COUNT_ATTR_INDEX); + if (groupCountPtr != nullptr) { + groupCount_ = *groupCountPtr; + OP_LOGI(context_, "Attr group_count is: %ld", groupCount_); + moeGatingTopKTilingData_.set_groupCount(groupCount_); + } + OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_); + + const int64_t *groupSelectModePtr = attrs->GetAttrPointer(GROUP_SELECT_MODE_ATTR_INDEX); + if (groupSelectModePtr != nullptr) { + groupSelectMode_ = *groupSelectModePtr; + OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_); + moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_); + } + OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_); + + const int64_t *renormPtr = attrs->GetAttrPointer(RENORM_ATTR_INDEX); + if (renormPtr != nullptr) { + renorm_ = *renormPtr; + OP_LOGI(context_, "Attr renorm is: %ld", renorm_); + moeGatingTopKTilingData_.set_renorm(renorm_); + } + OP_LOGI(context_, "Attr renorm is: %ld ", renorm_); + + const int64_t *normTypePtr = attrs->GetAttrPointer(NORM_TYPE_ATTR_INDEX); + if (normTypePtr != nullptr) { + normType_ = *normTypePtr; + OP_LOGI(context_, "Attr norm_type is: %ld", normType_); + moeGatingTopKTilingData_.set_normType(normType_); + } + 
OP_LOGI(context_, "Attr norm_type is: %ld ", normType_); + + const bool *outFlagPtr = attrs->GetAttrPointer(OUT_FLAG_ATTR_INDEX); + if (outFlagPtr != nullptr) { + outFlag_ = (*outFlagPtr) ? 1 : 0; + OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_); + moeGatingTopKTilingData_.set_outFlag(outFlag_); + } + OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_); + + const float *routedScalingFactorPtr = attrs->GetAttrPointer(ROUTED_SCALING_FACTOR_ATTR_INDEX); + if (routedScalingFactorPtr != nullptr) { + routedScalingFactor_ = *routedScalingFactorPtr; + OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_); + moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_); + } + OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_); + + const float *epsPtr = attrs->GetAttrPointer(EPS_ATTR_INDEX); + if (epsPtr != nullptr) { + eps_ = *epsPtr; + OP_LOGI(context_, "Attr eps is: %f", eps_); + moeGatingTopKTilingData_.set_eps(eps_); + } + OP_LOGI(context_, "Attr eps is: %f ", eps_); + + inputDtypeSize_ = static_cast(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType())); + OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo() + +{ + auto platformInfo = context_->GetPlatformInfo(); + OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); + aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSizePlatForm; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm); + aicoreParams_.ubSize = ubSizePlatForm; + OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape() +{ + OP_LOGI(context_, 
"555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str());
+    OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()),
+        OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(),
+            xShape_->GetDimNum()),
+        return ge::GRAPH_FAILED);
+    OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()),
+        OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.",
+            expertIdxShape_->GetDimNum(), xShape_->GetDimNum()),
+        return ge::GRAPH_FAILED);
+    if (outShape_ != nullptr) {
+        OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()),
+            OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.",
+                outShape_->GetDimNum(), xShape_->GetDimNum()),
+            return ge::GRAPH_FAILED);
+    }
+
+    OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)),
+        OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0),
+            xShape_->GetDim(0)),
+        return ge::GRAPH_FAILED);
+    OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)),
+        OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.",
+            expertIdxShape_->GetDim(0), xShape_->GetDim(0)),
+        return ge::GRAPH_FAILED);
+    if (outFlag_ && outShape_ != nullptr) {
+        // BUGFIX: the second %ld previously printed outShape_->GetDim(0) again,
+        // so the "x dim[0]" value in the log was wrong; print xShape_'s dim.
+        OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)),
+            OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.",
+                outShape_->GetDim(0), xShape_->GetDim(0)),
+            return ge::GRAPH_FAILED);
+    }
+
+    OP_CHECK_IF((yShape_->GetDim(1) != k_),
+        OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_),
+        return ge::GRAPH_FAILED);
+    OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_),
+        OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_),
+        return ge::GRAPH_FAILED);
+    if (outFlag_ && outShape_ != nullptr) {
+        OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)),
+ OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1), + xShape_->GetDim(1)), + return ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +void MoeGatingTopKTilingBase::SplitRows() +{ + int64_t perCoreRows = CEIL_DIV(rows_, static_cast(aicoreParams_.blockDim)); + int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows); + // perCoreRows cannot be 0 + int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows; + moeGatingTopKTilingData_.set_needCoreNum(needCoreNum); + moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows); + moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows); + int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L)); + OP_LOGI(context_, "vms count is: %ld", vmsCount); + moeGatingTopKTilingData_.set_vmsCount(vmsCount); +} + +void MoeGatingTopKTilingBase::CalTmpBufUbSize() + +{ + + std::vector shape_vec = {expertCount_}; + ge::Shape shape(shape_vec); + uint32_t maxValue = 0; + uint32_t minValue = 0; + AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue); + + int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast(sizeof(float)); + moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast(minValue))); +} + +ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling() +{ + + OP_LOGI(context_, "DoOpTiling: start"); + auto ret = CheckInputShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckOutShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckAttr(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + CalTmpBufUbSize(); + SplitRows(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling() +{ + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize() +{ + + workspaceSize_ = DEFAULT_WORKSPACE_SIZE; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingBase::PostTiling() +{ + 
context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum()); + size_t *currentWorkspace = context_->GetWorkspaceSizes(1); + currentWorkspace[0] = workspaceSize_; + moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), + context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize()); + return ge::GRAPH_SUCCESS; +} + +uint64_t MoeGatingTopKTilingBase::GetTilingKey() const +{ + + if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ && + groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID && + !outFlag_) { + + return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF; + } else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) { + return TILING_KEY_WITHOUT_GROUP; + } else { + return TILING_KEY_GENERALIZED; + } +} + +void MoeGatingTopKTilingBase::Reset() +{ + opName_ = nullptr; + return; +} + +REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000); +} // namespace optiling diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h new file mode 100644 index 00000000..6ac4aa46 --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h @@ -0,0 +1,86 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_tiling.h + * \brief + */ + +#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H +#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H + +#include +#include +#include +#include + + +#include "../tiling_base/tiling_base.h" +#include "../tiling_base/tiling_templates_registry.h" +#include "register/op_def_registry.h" +#include "register/op_impl_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "error_log.h" + +#include "register/op_impl_registry.h" +#include "platform/platform_infos_def.h" +#include "math_util.h" +//#include "util/extern_math_util.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, rowCount); +TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount); +TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount); +TILING_DATA_FIELD_DEF(int64_t, expertCount); +TILING_DATA_FIELD_DEF(int64_t, addBias); +TILING_DATA_FIELD_DEF(int64_t, k); +TILING_DATA_FIELD_DEF(int64_t, kGroup); +TILING_DATA_FIELD_DEF(int64_t, groupCount); +TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount); +TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); +TILING_DATA_FIELD_DEF(int64_t, groupSelectMode); +TILING_DATA_FIELD_DEF(int64_t, renorm); +TILING_DATA_FIELD_DEF(int64_t, normType); +TILING_DATA_FIELD_DEF(int64_t, outFlag); +TILING_DATA_FIELD_DEF(int64_t, vmsCount); +TILING_DATA_FIELD_DEF(float, routedScalingFactor); +TILING_DATA_FIELD_DEF(float, eps); +TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData) + +BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, rowCount); +TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount); +TILING_DATA_FIELD_DEF(int64_t, 
lastCoreRowCount); +TILING_DATA_FIELD_DEF(int64_t, expertCount); +TILING_DATA_FIELD_DEF(int64_t, addBias); +TILING_DATA_FIELD_DEF(int64_t, k); +TILING_DATA_FIELD_DEF(int64_t, kGroup); +TILING_DATA_FIELD_DEF(int64_t, groupCount); +TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount); +TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); +TILING_DATA_FIELD_DEF(int64_t, groupSelectMode); +TILING_DATA_FIELD_DEF(int64_t, renorm); +TILING_DATA_FIELD_DEF(int64_t, normType); +TILING_DATA_FIELD_DEF(int64_t, outFlag); +TILING_DATA_FIELD_DEF(int64_t, vmsCount); +TILING_DATA_FIELD_DEF(float, routedScalingFactor); +TILING_DATA_FIELD_DEF(float, eps); +TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData) +struct MoeGatingTopKCompileInfo {}; +} // namespace optiling +#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp new file mode 100644 index 00000000..f170d3dd --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp @@ -0,0 +1,521 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! 
+ * \file moe_gating_top_k_tiling_arch35.cpp + * \brief + */ + +#include "error_log.h" +#include "moe_gating_top_k_tiling.h" +#include "register/op_def_registry.h" +#include "platform/platform_info.h" +#include "../tiling_base/tiling_base.h" +#include "../tiling_base/tiling_templates_registry.h" + +#ifndef CEIL_ALIGN +#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align)) +#endif +#ifndef CEIL_DIV +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) +#endif +namespace optiling { +const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000; + +const static int64_t GROUP_SELECT_MODE_MAX = 0; +const static int64_t GROUP_SELECT_MODE_SUM = 1; +const static int64_t RENORM_NO = 0; +const static int64_t RENORM_L1 = 1; +const static int64_t NORM_TYPE_SOFTMAX = 0; +const static int64_t NORM_TYPE_SIGMOID = 1; +const static int64_t OUT_FLAG_FALSE = 0; +const static int64_t OUT_FLAG_TRUE = 1; +const static size_t X_INPUT_DIMS = 2; +const static size_t BIAS_INPUT_DIMS = 1; +const static size_t Y_OUTPUT_DIMS = 2; +const static size_t EXPERT_IDX_OUTPUY_DIMS = 2; +const static size_t OUT_OUTPUT_DIMS = 2; +const static int64_t MAX_EXPERT_COUNT = 2048; + +const static int64_t X_INPUT_INDEX = 0; +const static int64_t BIAS_INPUT_INDEX = 1; +const static int64_t Y_OUTPUT_INDEX = 0; +const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1; +const static int64_t OUT_OUTPUT_INDEX = 2; +const static int64_t K_ATTR_INDEX = 0; +const static int64_t K_GROUP_ATTR_INDEX = 1; +const static int64_t GROUP_COUNT_ATTR_INDEX = 2; +const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3; +const static int64_t RENORM_ATTR_INDEX = 4; +const static int64_t MRGSORT_SIZE = 4; +const static int64_t NORM_TYPE_ATTR_INDEX = 5; +const static int64_t OUT_FLAG_ATTR_INDEX = 6; +const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7; +const static int64_t EPS_ATTR_INDEX = 8; +const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast(16 * 1024 * 1024); // 预留16M空间 + + +class MoeGatingTopKTilingRegbase 
: public Ops::Transformer::OpTiling::TilingBaseClass { +public: + explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context) + { + Reset(); + } + ~MoeGatingTopKTilingRegbase() override = default; + + void Reset(gert::TilingContext *context) override + { + TilingBaseClass::Reset(context); + Reset(); + } + +protected: + bool IsCapable() override + { + if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) { + return false; + } + return true; + } + + ge::graphStatus GetPlatformInfo() override; + + ge::graphStatus GetShapeAttrsInfo() override; + + ge::graphStatus DoOpTiling() override; + + ge::graphStatus DoLibApiTiling() override; + + uint64_t GetTilingKey() const override; + + ge::graphStatus GetWorkspaceSize() override; + + ge::graphStatus PostTiling() override; + void Reset(); + +private: + ge::graphStatus CheckInputShape(); + ge::graphStatus CheckAttr(); + ge::graphStatus CheckOutShape(); + void CalTmpBufUbSize(); + void SplitRows(); + void Tiling4GatherOutComputeSplitK(); + + const gert::Shape *xShape_ = nullptr; + const gert::Shape *biasShape_ = nullptr; + const gert::Shape *yShape_ = nullptr; + const gert::Shape *expertIdxShape_ = nullptr; + const gert::Shape *outShape_ = nullptr; + + int64_t rows_; + int64_t expertCount_; + int64_t addBias_ = 0; + + int64_t k_; + int64_t kGroup_ = 1; + int64_t groupCount_ = 1; + int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX; + int64_t renorm_ = RENORM_NO; + int64_t normType_ = NORM_TYPE_SOFTMAX; + int64_t outFlag_ = OUT_FLAG_FALSE; + float routedScalingFactor_ = 1.0; + float eps_ = 1e-20f; + + int64_t inputDtypeSize_; + const char *opName_ = ""; + MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_; + platform_ascendc::SocVersion socVersion; +}; + +ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape() +{ + size_t xDimNum = xShape_->GetDimNum(); + OP_CHECK_IF(xDimNum != X_INPUT_DIMS, + OP_LOGE(context_, "The dim number of x is: %zu, but should 
be %zu.", xDimNum, X_INPUT_DIMS), + return ge::GRAPH_FAILED); + + + rows_ = xShape_->GetDim(0); + expertCount_ = xShape_->GetDim(1); + moeGatingTopKTilingData_.set_rowCount(rows_); + moeGatingTopKTilingData_.set_expertCount(expertCount_); + OP_CHECK_IF( + expertCount_ > MAX_EXPERT_COUNT, + OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT), + return ge::GRAPH_FAILED); + + if (biasShape_ != nullptr) { + addBias_ = 1; + size_t biasDimNum = biasShape_->GetDimNum(); + OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS, + OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS), + return ge::GRAPH_FAILED); + OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_, + OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.", + biasShape_->GetDim(0), expertCount_), + return ge::GRAPH_FAILED); + } + moeGatingTopKTilingData_.set_addBias(addBias_); + + OP_CHECK_IF(k_ > expertCount_, + OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr() +{ + OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED); + OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_), + return ge::GRAPH_FAILED); + OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_), + return ge::GRAPH_FAILED); + OP_CHECK_IF(expertCount_ % groupCount_ != 0, + OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_), + return ge::GRAPH_FAILED); + OP_CHECK_IF(kGroup_ > groupCount_, + OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_), + return ge::GRAPH_FAILED); + 
OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_, + OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.", + kGroup_, k_), + return ge::GRAPH_FAILED); + + if (kGroup_ == groupCount_ || groupCount_ == expertCount_) { + kGroup_ = 1; + groupCount_ = 1; + } + moeGatingTopKTilingData_.set_kGroup(kGroup_); + moeGatingTopKTilingData_.set_groupCount(groupCount_); + int64_t groupExpertCount = expertCount_ / groupCount_; + int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L); + moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_); + moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign); + + OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT, + OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.", + groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(kGroup_ * groupExpertCount < k_, + OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.", + kGroup_ * groupExpertCount, k_), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(groupExpertCount < 1, + OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount), + return ge::GRAPH_FAILED); + OP_CHECK_IF( + groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX, + OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_, + GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX), + return ge::GRAPH_FAILED); + OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2, + OP_LOGE(context_, + "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.", + groupExpertCount, groupSelectMode_), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(renorm_ != RENORM_NO, + OP_LOGE(context_, "renorm is: %ld, 
but currently only support %ld.", renorm_, RENORM_NO), + return ge::GRAPH_FAILED); + + OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID, + OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_, + NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo() +{ + opName_ = context_->GetNodeName(); + + auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr); + xShape_ = &xShapePtr->GetStorageShape(); + auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX); + biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape(); + + + auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr); + yShape_ = &yShapePtr->GetStorageShape(); + auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr); + expertIdxShape_ = &expertIdxPtr->GetStorageShape(); + auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX); + if (outPtr != nullptr) { + outShape_ = &outPtr->GetStorageShape(); + } + + auto x = context_->GetInputDesc(X_INPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, x); + auto xDtype = x->GetDataType(); + OP_CHECK_IF( + (xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16), + OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. 
please check.", + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + + if (biasShapePtr != nullptr) { + auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType(); + OP_CHECK_IF((biasDtype != xDtype), + OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.", + ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(), + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + } + + auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc); + auto yDtype = yDesc->GetDataType(); + OP_CHECK_IF((yDtype != xDtype), + OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.", + ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(), + ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), + return ge::GRAPH_FAILED); + + auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc); + auto expertIdDtype = expertIdDesc->GetDataType(); + OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32), + OP_LOGE(context_, "expertId out dtype %s error, only supports int32. 
please check.", + ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()), + return ge::GRAPH_FAILED); + + + auto attrs = context_->GetAttrs(); + OP_CHECK_NULL_WITH_CONTEXT(context_, attrs); + + const int64_t *kPtr = attrs->GetAttrPointer(K_ATTR_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr); + k_ = *kPtr; + moeGatingTopKTilingData_.set_k(k_); + OP_LOGI(context_, "Attr k is: %ld ", k_); + + const int64_t *kGroupPtr = attrs->GetAttrPointer(K_GROUP_ATTR_INDEX); + if (kGroupPtr != nullptr) { + kGroup_ = *kGroupPtr; + } + OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_); + + const int64_t *groupCountPtr = attrs->GetAttrPointer(GROUP_COUNT_ATTR_INDEX); + if (groupCountPtr != nullptr) { + groupCount_ = *groupCountPtr; + } + OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_); + + const int64_t *groupSelectModePtr = attrs->GetAttrPointer(GROUP_SELECT_MODE_ATTR_INDEX); + if (groupSelectModePtr != nullptr) { + groupSelectMode_ = *groupSelectModePtr; + } + moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_); + OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_); + + const int64_t *renormPtr = attrs->GetAttrPointer(RENORM_ATTR_INDEX); + if (renormPtr != nullptr) { + renorm_ = *renormPtr; + } + moeGatingTopKTilingData_.set_renorm(renorm_); + OP_LOGI(context_, "Attr renorm is: %ld ", renorm_); + + const int64_t *normTypePtr = attrs->GetAttrPointer(NORM_TYPE_ATTR_INDEX); + if (normTypePtr != nullptr) { + normType_ = *normTypePtr; + } + moeGatingTopKTilingData_.set_normType(normType_); + OP_LOGI(context_, "Attr norm_type is: %ld ", normType_); + + const bool *outFlagPtr = attrs->GetAttrPointer(OUT_FLAG_ATTR_INDEX); + if (outFlagPtr != nullptr) { + outFlag_ = (*outFlagPtr) ? 
1 : 0; + } + moeGatingTopKTilingData_.set_outFlag(outFlag_); + OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_); + + const float *routedScalingFactorPtr = attrs->GetAttrPointer(ROUTED_SCALING_FACTOR_ATTR_INDEX); + if (routedScalingFactorPtr != nullptr) { + routedScalingFactor_ = *routedScalingFactorPtr; + } + moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_); + OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_); + + const float *epsPtr = attrs->GetAttrPointer(EPS_ATTR_INDEX); + if (epsPtr != nullptr) { + eps_ = *epsPtr; + } + moeGatingTopKTilingData_.set_eps(eps_); + OP_LOGI(context_, "Attr eps is: %f ", eps_); + + auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX); + if (outFlag_ && outDesc != nullptr) { + auto outDtype = outDesc->GetDataType(); + OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT), + OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.", + ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()), + return ge::GRAPH_FAILED); + } + + inputDtypeSize_ = static_cast(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType())); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo() +{ + auto platformInfo = context_->GetPlatformInfo(); + OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); + aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv(); + socVersion = ascendcPlatform.GetSocVersion(); + uint64_t ubSizePlatForm; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm); + aicoreParams_.ubSize = ubSizePlatForm; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape() +{ + OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()), + OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", 
yShape_->GetDimNum(), + xShape_->GetDimNum()), + return ge::GRAPH_FAILED); + OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()), + OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.", + expertIdxShape_->GetDimNum(), xShape_->GetDimNum()), + return ge::GRAPH_FAILED); + if (outShape_ != nullptr) { + OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()), + OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.", + outShape_->GetDimNum(), xShape_->GetDimNum()), + return ge::GRAPH_FAILED); + } + + OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)), + OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0), + xShape_->GetDim(0)), + return ge::GRAPH_FAILED); + OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)), + OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.", + expertIdxShape_->GetDim(0), xShape_->GetDim(0)), + return ge::GRAPH_FAILED); + if (outFlag_ && outShape_ != nullptr) { + OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)), + OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.", + outShape_->GetDim(0), outShape_->GetDim(0)), + return ge::GRAPH_FAILED); + } + + OP_CHECK_IF((yShape_->GetDim(1) != k_), + OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_), + return ge::GRAPH_FAILED); + OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_), + OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_), + return ge::GRAPH_FAILED); + if (outFlag_ && outShape_ != nullptr) { + OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)), + OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1), + xShape_->GetDim(1)), + return ge::GRAPH_FAILED); + } + return ge::GRAPH_SUCCESS; +} + +void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() { + 
std::vector shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()}; + ge::Shape softmaxShape(shape_vec); + + uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true); + AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData); +} + +void MoeGatingTopKTilingRegbase::SplitRows() +{ + int64_t perCoreRows = CEIL_DIV(rows_, static_cast(aicoreParams_.blockDim)); + int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows); + if (perCoreRows == 0) { + OP_LOGE(context_, "perCoreRows can't be 0."); + return; + } + int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows; + moeGatingTopKTilingData_.set_needCoreNum(needCoreNum); + moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows); + moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows); + + int64_t vmsCount = 0; + if (kGroup_ > MRGSORT_SIZE) { + int64_t index = MRGSORT_SIZE; + while (index < kGroup_) { + index = index * MRGSORT_SIZE; + vmsCount++; + } + } + moeGatingTopKTilingData_.set_vmsCount(vmsCount); +} + +ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling() +{ + auto ret = CheckInputShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckAttr(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckOutShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + CalTmpBufUbSize(); + SplitRows(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling() +{ + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize() +{ + + workspaceSize_ = DEFAULT_WORKSPACE_SIZE; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling() +{ + context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum()); + size_t *currentWorkspace = context_->GetWorkspaceSizes(1); + currentWorkspace[0] = workspaceSize_; + 
moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), + context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize()); + return ge::GRAPH_SUCCESS; +} + +uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const +{ + return MOE_GATING_TOP_K_REGBASE_TILING_KEY; +} + +void MoeGatingTopKTilingRegbase::Reset() +{ + opName_ = nullptr; + return; +} + +REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000); +} // namespace optiling diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp new file mode 100644 index 00000000..db9ce60d --- /dev/null +++ b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp @@ -0,0 +1,38 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! 
+ * \file moe_gating_top_k_tiling_base.cpp + * \brief + */ +#include "moe_gating_top_k_tiling.h" +#include "register/op_def_registry.h" +#include "../tiling_base/tiling_base.h" +#include "../tiling_base/tiling_templates_registry.h" +#include "error_log.h" +#include "kernel_tiling/kernel_tiling.h" + +namespace optiling { +static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context) +{ + return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context); +} + +static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(MoeGatingTopK) + .Tiling(TilingForMoeGatingTopK) + .TilingParse(TilingPrepareForMoeGatingTopK); + +} // namespace optiling \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/common.h b/csrc/moe_gating_top_k/op_kernel/common.h new file mode 100644 index 00000000..0847ee2b --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/common.h @@ -0,0 +1,89 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file common.h + * \brief + */ +#ifndef MOE_GATING_TOP_K_COMMON_H +#define MOE_GATING_TOP_K_COMMON_H + +#include "kernel_operator.h" + +namespace MoeGatingTopK { +using namespace AscendC; +const float MIN_FP32 = *(float *)(&F32_NEG_INF); +constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040 +constexpr int64_t ONE_REPEAT_SORT_NUM = 32; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t REPEAT_BYTES = 256; +constexpr int64_t REPEAT_BLOCKS = 8; + +constexpr int32_t CONSTANT_TWO = 2; +constexpr int32_t CONSTANT_THREE = 3; +constexpr int32_t CONSTANT_FOUR = 4; +constexpr int32_t CONSTANT_EIGHT = 8; + +constexpr int64_t MERGE_LIST_TWO = 2; +constexpr int64_t MERGE_LIST_THREE = 3; +constexpr int64_t MERGE_LIST_FOUR = 4; + +constexpr int64_t MERGE_LIST_IDX_TWO = 2; +constexpr int64_t MERGE_LIST_IDX_THREE = 3; + +constexpr int64_t NORM_TYPE_SOFTMAX = 0; +constexpr int64_t NORM_TYPE_SIGMOID = 1; + +__aicore__ inline int64_t Ceil(int64_t a, int64_t b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) +{ + if (bytes == 0) { + return 0; + } + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes; +} + +__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) +{ + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES; +} + +template +__aicore__ inline T Min(T a, T b) +{ + return a > b ? b : a; +} + +template +__aicore__ inline T Max(T a, T b) +{ + return a < b ? b : a; +} + +template +__aicore__ inline T1 CeilDiv(T1 x, T2 y) +{ + if (y != 0 && x != 0) { + const T1 quotient = x / y; + return (x % y != 0 && ((x ^ y) >= 0)) ? 
(quotient + 1) : quotient; + } + + return x; +} + +} // namespace MoeGatingTopK +#endif // MOE_GATING_TOP_K_COMMON_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/error_log.h b/csrc/moe_gating_top_k/op_kernel/error_log.h new file mode 100644 index 00000000..f48985a4 --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/error_log.h @@ -0,0 +1,55 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_CHECK_IF(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) + + + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if ((ptr) == nullptr) { \ + OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp new file mode 100644 index 00000000..4390fee4 --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp @@ -0,0 +1,63 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k.cpp + * \brief + */ + +#include "moe_gating_top_k_e_k_fullload.h" +#include "moe_gating_top_k_without_group.h" +#include "moe_gating_top_k_generalized.h" +#include "error_log.h" + +#define TILING_KEY_PER_GROUP_COUNT_32 0 +#define TILING_KEY_WITHOUT_GROUP 1 +#define TILING_KEY_GENERALIZED 2 + +using namespace AscendC; +using namespace MoeGatingTopK; +extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, + GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling) +{ + + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); + if (g_coreType == AIC) { + return; + } + + + GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling); + if (workspace == nullptr) { + return; + } + + GM_ADDR userWS = GetUserWorkspace(workspace); + if (userWS == nullptr) { + return; + } + + const MoeGatingTopKTilingData *__restrict t = &tilingData; + TPipe tPipe; + if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) { + MoeGatingTopKEKFullload op; + op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); + op.Process(); + } else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) { + MoeGatingTopKWithoutGroup op; + op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); + op.Process(); + } else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) { + MoeGatingTopKGenerlized op; + op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); + op.Process(); + } + +} diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp 
b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp new file mode 100644 index 00000000..5abaf2ba --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp @@ -0,0 +1,46 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_apt.cpp + * \brief + */ + +#include "arch35/moe_gating_top_k_regbase.h" +using namespace AscendC; +using namespace MoeGatingTopK; + +#define TILING_KEY_REGBASE 10000 + +extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, + GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling) +{ + if (g_coreType == AIC) { + return; + } + + if (workspace == nullptr) { + return; + } + + GM_ADDR userWS = GetUserWorkspace(workspace); + if (userWS == nullptr) { + return; + } + + GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling); + const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in; + TPipe tPipe; + if (TILING_KEY_IS(TILING_KEY_REGBASE)) { + MoeGatingTopKRegbase op; + op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe); + op.Process(); + } +} diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h new file mode 100644 index 00000000..6891cc0e --- /dev/null +++ 
b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h @@ -0,0 +1,404 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_e_k_fullload.h + * \brief + */ +#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H +#define MOE_GATING_TOP_K_E_K_FULLLOAD_H +#include "kernel_operator.h" +#include "common.h" +namespace MoeGatingTopK { +using namespace AscendC; + +template +class MoeGatingTopKEKFullload { +public: + __aicore__ inline MoeGatingTopKEKFullload(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyInBias(); + __aicore__ inline void CopyInX(int64_t progress); + __aicore__ inline void ComputeX(); + __aicore__ inline void SortInGroup(); + __aicore__ inline void SelectTopKGroupIndex(); + __aicore__ inline void SelectTopKExpertIdx(); + __aicore__ inline void SelectTopKExpertScore(); + __aicore__ inline void CopyOut(int64_t progress); + +private: + TPipe *pipe_; + TQue xInQueue_; + TBuf biasInQueue_; + TQue yOutQueue_; + TQue expertIdxOutQueue_; + TQue outOutQueue_; + + TQue xBiasQueue_; + TQue xSigmoidQueue_; + TQue sigmoidTmpQueue_; + TQue sortedInGroupQueue_; + TQue sortedGroupQueue_; + TBuf calcTmpBuffer_; + + 
GlobalTensor xGm_; + GlobalTensor biasGm_; + GlobalTensor yGm_; + GlobalTensor expertIdxGm_; + GlobalTensor outGm_; + + int64_t blockIdx_; + int64_t perCoreRowCount_; + int64_t curCoreRowCount_; + int64_t expertCount_; + bool addBias_; + int64_t k_; + int64_t kGroup_; + int64_t groupCount_; + int64_t groupSelectMode_; + int64_t renorm_; + int64_t normType_; + int64_t outFlag_; + float routedScalingFactor_; + float eps_; + + int64_t expertCountAlign_; + int64_t kAlign_; + int64_t perGroupExpertCount_; + + const MoeGatingTopKTilingData *tilingData_; +}; + +template +__aicore__ inline void MoeGatingTopKEKFullload::CopyInBias() +{ + LocalTensor biasTensor = biasInQueue_.Get(); + DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if constexpr (IsSameType::value) { + DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + } else { + DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, expertCount_); + } +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::CopyInX(int64_t row) +{ + LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if constexpr (IsSameType::value) { + DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); + } else { + DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], 
dataCopyParams, + dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, + expertCount_); + } + + xInQueue_.EnQue(xInLocalTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::ComputeX() +{ + LocalTensor xSigmoidTensor = xSigmoidQueue_.AllocTensor(); + LocalTensor xInLocalTensor = xInQueue_.DeQue(); + LocalTensor xBiasTensor = xBiasQueue_.AllocTensor(); + LocalTensor biasTensor = biasInQueue_.Get(); + LocalTensor sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor(); // 临时空间可以复用 + Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_); + PipeBarrier(); + if (addBias_) { + Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_); + } else { + Adds(xBiasTensor, xSigmoidTensor, static_cast(0), expertCount_); + } + + xSigmoidQueue_.EnQue(xSigmoidTensor); + xBiasQueue_.EnQue(xBiasTensor); + xInQueue_.FreeTensor(xInLocalTensor); + sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::SortInGroup() +{ + LocalTensor xBiasTensor = xBiasQueue_.DeQue(); + LocalTensor sortedInGroupTensor = sortedInGroupQueue_.AllocTensor(); // 组内排序的结果, 后续归并需要 + LocalTensor indexTensor = calcTmpBuffer_.Get(); // 用于存储排序时的索引 + ArithProgression(indexTensor.ReinterpretCast(), 0, 1, expertCount_); // 生成组索引0 1 2 ...... 
+ PipeBarrier(); + Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // 组内排序 + sortedInGroupQueue_.EnQue(sortedInGroupTensor); + xBiasQueue_.FreeTensor(xBiasTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKGroupIndex() +{ + LocalTensor sortedInGroupTensor = sortedInGroupQueue_.DeQue(); + LocalTensor indexTensor = calcTmpBuffer_.Get(); + LocalTensor top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor(); // 这个临时空间可以复用 + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + + indexTensor.SetValue(0, static_cast(5)); // b0101 + indexTensor.SetValue(1, static_cast(0)); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + + uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = 8; + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = 8; + gatherMaskParams.src1RepeatStride = 0; + GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast(64), + gatherMaskParams, rsvdCnt); + PipeBarrier(); + LocalTensor groupTop2SumTensor = top2ValueInGroupTensor; + PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1, + 1); // 计算每个组内最大的两个数之和 + PipeBarrier(); + + LocalTensor groupIndexTensor = indexTensor; + ArithProgression(groupIndexTensor.ReinterpretCast(), 0, 1, groupCount_); // 生成组索引 + PipeBarrier(); + // 用最小值补到32个数 + int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_; + if (duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX << groupCount_; + uint64_t mask[2] = {mask0, 0}; + Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8); + PipeBarrier(); + } + // 排序,将kgroup选出来 + LocalTensor sortedGroupTensor = sortedGroupQueue_.AllocTensor(); + Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1); + + 
PipeBarrier(); + LocalTensor sortedGroupIndexTensor = indexTensor.ReinterpretCast(); + // 提取组序号 + uint8_t src1Pattern = 2; // 内置固定模式 + GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast(), src1Pattern, false, + static_cast(0), {1, 1, 0, 0}, rsvdCnt); + + // 需要将组排序(这里是降序,所以下mrgsor的时候反着取,3、2、1、0) + Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_); + PipeBarrier(); + duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_; + if (duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX << kGroup_; + uint64_t mask[2] = {mask0, 0}; + Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8); + PipeBarrier(); + } + Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast(), 1); + PipeBarrier(); + src1Pattern = 1; + GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast(0), {1, 1, 0, 0}, + rsvdCnt); + PipeBarrier(); + Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_); + + sortedGroupQueue_.FreeTensor(sortedGroupTensor); + sortedInGroupQueue_.EnQue(sortedInGroupTensor); + sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKExpertIdx() +{ + LocalTensor expertIdxTensor = expertIdxOutQueue_.AllocTensor(); + LocalTensor topKGroupIndexTensor = calcTmpBuffer_.Get(); + LocalTensor sortedInGroupTensor = sortedInGroupQueue_.DeQue(); + LocalTensor sortedExpertTensor = xInQueue_.AllocTensor(); + AscendC::MrgSort4Info params; + params.elementLengths[0] = k_; + params.elementLengths[1] = k_; + params.elementLengths[2] = k_; + params.elementLengths[3] = k_; + params.ifExhaustedSuspension = true; + params.validBit = 0b1111; + params.repeatTimes = 1; + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2; + int64_t 
listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2; + int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2; + int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2; + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + AscendC::MrgSortSrcList srcList; + srcList.src1 = sortedInGroupTensor[listOffset1]; + srcList.src2 = sortedInGroupTensor[listOffset2]; + srcList.src3 = sortedInGroupTensor[listOffset3]; + srcList.src4 = sortedInGroupTensor[listOffset4]; + MrgSort(sortedExpertTensor, srcList, params); + PipeBarrier(); + uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 + uint8_t src1Pattern = 2; // 内置固定模式 + GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast(), src1Pattern, false, + static_cast(0), {1, 1, 0, 0}, rsvdCnt); + xInQueue_.FreeTensor(sortedExpertTensor); + expertIdxOutQueue_.EnQue(expertIdxTensor); + sortedInGroupQueue_.FreeTensor(sortedInGroupTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKExpertScore() +{ + LocalTensor expertIdxTensor = expertIdxOutQueue_.DeQue(); + LocalTensor expertByteIdxTensor = calcTmpBuffer_.Get(); + LocalTensor xSigmoidTensor = xSigmoidQueue_.DeQue(); + LocalTensor yTensor = yOutQueue_.AllocTensor(); + LocalTensor yOutTensor; + if constexpr (!IsSameType::value) { + yOutTensor = yTensor.template ReinterpretCast()[kAlign_]; + } else { + yOutTensor = yTensor; + } + Muls(expertByteIdxTensor, expertIdxTensor, static_cast(sizeof(float)), k_); + PipeBarrier(); + Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast(), + static_cast(0), k_); + + LocalTensor calTensor = calcTmpBuffer_.Get(); + PipeBarrier(); + ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float sumValue = 
calTensor.GetValue(0) + eps_; + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Duplicate(calTensor, sumValue, k_); + PipeBarrier(); + Div(yOutTensor, yOutTensor, calTensor, k_); + PipeBarrier(); + Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_); + + if constexpr (!IsSameType::value) { + PipeBarrier(); + Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_); + } + + xSigmoidQueue_.EnQue(xSigmoidTensor); + expertIdxOutQueue_.EnQue(expertIdxTensor); + yOutQueue_.EnQue(yTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::CopyOut(int64_t row) +{ + LocalTensor yOutTensor = yOutQueue_.DeQue(); + LocalTensor expertIdxTensor = expertIdxOutQueue_.DeQue(); + LocalTensor xSigmoidTensor = xSigmoidQueue_.DeQue(); + DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; + DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); + dataCopyParams.blockLen = k_ * sizeof(int32_t); + DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams); + xSigmoidQueue_.FreeTensor(xSigmoidTensor); + expertIdxOutQueue_.FreeTensor(expertIdxTensor); + yOutQueue_.FreeTensor(yOutTensor); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, + GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) +{ + tilingData_ = tilingData; + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + perCoreRowCount_ = tilingData_->perCoreRowCount; + if (blockIdx_ == GetBlockNum() - 1) { + curCoreRowCount_ = tilingData_->lastCoreRowCount; + } else { + curCoreRowCount_ = tilingData_->perCoreRowCount; + } + expertCount_ = tilingData_->expertCount; + addBias_ = tilingData_->addBias == 1; + k_ = tilingData_->k; + kGroup_ = tilingData_->kGroup; + groupCount_ = tilingData_->groupCount; + perGroupExpertCount_ = tilingData_->perGroupExpertCount; + routedScalingFactor_ = 
tilingData_->routedScalingFactor; + eps_ = tilingData_->eps; + + expertCountAlign_ = Align(expertCount_, sizeof(float)); + kAlign_ = Align(expertCount_, sizeof(float)); + + // init input gm buf + xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); + biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); + + // init output gm buf + yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); + outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); + + // init que + pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + + pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float))); + pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float))); + + pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t))); + pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float))); + + pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float))); + pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2); + pipe_->InitBuffer(sortedGroupQueue_, 2, + (groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM * + sizeof(float) * 2); + + pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize); +} + +template +__aicore__ inline void MoeGatingTopKEKFullload::Process() +{ + CopyInBias(); + for (int64_t row = 0; row < curCoreRowCount_; row++) { + CopyInX(row); + ComputeX(); + SortInGroup(); + SelectTopKGroupIndex(); + SelectTopKExpertIdx(); + SelectTopKExpertScore(); + CopyOut(row); + } +} +} // namespace 
MoeGatingTopK +#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h new file mode 100644 index 00000000..b1cc14e6 --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h @@ -0,0 +1,669 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_gating_top_k_generalized.h + * \brief + */ +#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H +#define MOE_GATING_TOP_K_E_K_GENERALIZED_H +#include "kernel_operator.h" +#include "common.h" +#include "kernel_utils.h" +namespace MoeGatingTopK { +using namespace AscendC; + +template +class MoeGatingTopKGenerlized { +public: + __aicore__ inline MoeGatingTopKGenerlized(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyInBiasAndInitExpertId(); + __aicore__ inline void CopyInX(int64_t progress); + __aicore__ inline void ComputeX(); + __aicore__ inline void CopuOutXNorm(int64_t row); + __aicore__ inline void SortInGroup(); + __aicore__ inline void SelectTopKGroupIndex(); + __aicore__ inline void SelectTopKExpertIdx(); + __aicore__ inline void SelectTopKExpertScore(); + __aicore__ inline void CumputeActualTopKExpertId(); + __aicore__ inline void CopyOut(int64_t row); + +private: + TPipe *pipe_; + TQue xInQueue_; + TQue yOutQueue_; + TQue expertIdxOutQueue_; + TQue outOutQueue_; + + TBuf biasBuf_; // Store input bias + TBuf expertIdBuf_; // Expert ID + TBuf xNormWithBiasBuf_; // Store value after adding bias + TBuf xNormBuf_; // Store value after computing sigmoid or softmax + TBuf sortedInGroupBuf_; // Store sorted results within groups + TBuf topKExpertIdBuf_; + TBuf sortedGroupIndexBuf_; + TBuf calcTmpBuf_; + + GlobalTensor xGm_; + GlobalTensor biasGm_; + GlobalTensor yGm_; + GlobalTensor expertIdxGm_; + GlobalTensor outGm_; + + int64_t blockIdx_ = 0; + int64_t perCoreRowCount_ = 0; + int64_t curCoreRowCount_ = 0; + int64_t expertCount_ = 0; + bool addBias_ = false; + int64_t k_ = 0; + int64_t kGroup_ = 0; + int64_t groupCount_ = 0; + int64_t groupCountAlign_ = 0; + int64_t perGroupExpertCount_ = 0; + int64_t perGroupExpertCountAlign_ = 0; + int64_t 
groupSelectMode_ = 0; + int64_t renorm_ = 0; + int64_t normType_ = 0; + int64_t outFlag_ = 0; + + int64_t expertCountAlign_ = 0; + int64_t kAlign_ = 0; + bool isAlign_ = false; + + const MoeGatingTopKTilingData *tilingData_; +}; + +template +__aicore__ inline void MoeGatingTopKGenerlized::CopyInBiasAndInitExpertId() +{ + LocalTensor biasTensor = biasBuf_.Get(); + LocalTensor expertIdTensor = expertIdBuf_.Get(); + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = groupCount_; + dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES; + + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if (addBias_) { + if constexpr (IsSameType::value) { + DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + } else { + DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, + expertCountAlign_); + PipeBarrier(); + } + + if (!isAlign_) { + int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM; + int duplicateIndex = perGroupExpertCount_ - duplicateNum; + if (duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(biasTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1, + perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES); + } + } + } + ArithProgression(expertIdTensor, static_cast(0), static_cast(1), 
expertCountAlign_); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::CopyInX(int64_t row) +{ + LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = groupCount_; + dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES; + + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if constexpr (IsSameType::value) { + DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); + } else { + DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], dataCopyParams, + dataCopyPadParams); + } + xInQueue_.EnQue(xInLocalTensor); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::ComputeX() +{ + LocalTensor xNormTensor = xNormBuf_.Get(); + LocalTensor xInLocalTensor = xInQueue_.DeQue(); + LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); + LocalTensor biasTensor = biasBuf_.Get(); + + if constexpr (!IsSameType::value) { + Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, + expertCountAlign_); + PipeBarrier(); + } + + int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM; + int duplicateIndex = perGroupExpertCount_ - duplicateNum; + if (!isAlign_ && duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(xInLocalTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1, + (perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES); + PipeBarrier(); + } + if (normType_ == 1) { // sigmoid + LocalTensor calcNormTmpTensor = calcTmpBuf_.Get(); + Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_); + PipeBarrier(); + } + else if (normType_ 
== 0) { // softmax + LocalTensor reduceValueTensor = calcTmpBuf_.Get(); + LocalTensor calcTmp = calcTmpBuf_.Get()[BLOCK_BYTES]; + ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float maxValue = reduceValueTensor.GetValue(0); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_); + PipeBarrier(); + Exp(xNormTensor, xNormTensor, expertCountAlign_); + PipeBarrier(); + ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_); + eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float sumValue = reduceValueTensor.GetValue(0); + eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_); + PipeBarrier(); + } + if (addBias_) { + Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_); + } else { + DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_); + } + + if (!isAlign_ && duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + PipeBarrier(); + Duplicate(xNormWithBiasTensor.ReinterpretCast()[duplicateIndex], + FLOAT32_NEG_INF, // MIN_FP32, + mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES); + } + xInQueue_.FreeTensor(xInLocalTensor); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::CopuOutXNorm(int64_t row) +{ + LocalTensor outOutTensor = outOutQueue_.AllocTensor(); + LocalTensor xNormTensor = xNormBuf_.Get(); + DataCopy(outOutTensor, xNormTensor, expertCountAlign_); + 
outOutQueue_.EnQue(outOutTensor); + outOutTensor = outOutQueue_.DeQue(); + DataCopyExtParams dataCopyParams{ + static_cast(groupCount_), static_cast(perGroupExpertCount_ * sizeof(float)), + static_cast((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0}; + DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams); + outOutQueue_.FreeTensor(outOutTensor); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::SortInGroup() +{ + LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); + LocalTensor expertIdTensor = expertIdBuf_.Get(); + LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); + LocalTensor tmpLocal = calcTmpBuf_.Get(); + if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) { + PipeBarrier(); + Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_); + } else { + for (int64_t group = 0; group < groupCount_; group++) { + PipeBarrier(); + Sort(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO], + xNormWithBiasTensor[group * perGroupExpertCountAlign_], + expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal, + perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM); + } + } +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::SelectTopKGroupIndex() +{ + LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); + LocalTensor valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, 0); + LocalTensor maskTensor = + calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float)); + LocalTensor topValueInGroupTensor = + calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float)); + LocalTensor groupIndex = + calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float)); + LocalTensor sortedTopValue = + calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float)); + LocalTensor sortTmp = + calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, 
groupCountAlign_ * 7 * sizeof(float)); + + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + + uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering + PipeBarrier(); + if (groupSelectMode_ == 1) { // top2 sum + // Extract the first two elements of each group + maskTensor.SetValue(0, static_cast(5)); // b0101 + maskTensor.SetValue(1, static_cast(0)); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = groupCount_; + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = + Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES); + gatherMaskParams.src1RepeatStride = 0; + GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true, + static_cast(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt); + PipeBarrier(); + + // Calculate the sum of the first two numbers in each group + PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor, + Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1, + CONSTANT_EIGHT); // Calculate the sum of the two largest numbers in each group + } else { + maskTensor.SetValue(0, static_cast(1)); // b0101 + maskTensor.SetValue(1, static_cast(0)); + + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = groupCount_; + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32); + gatherMaskParams.src1RepeatStride = 0; + 
GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true, + static_cast(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt); + } + + PipeBarrier(); + // Generate group indices + ArithProgression(groupIndex.ReinterpretCast(), static_cast(0), static_cast(1), + groupCount_); // Generate group indices + PipeBarrier(); + + int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM; + int duplicateIndex = groupCount_ - duplicateNum; + if (duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(topValueInGroupTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, + REPEAT_BLOCKS); + PipeBarrier(); + } + PipeBarrier(); + + // Sort + Sort(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32)); + PipeBarrier(); + + // Extract group indices + uint8_t src1Pattern = 2; // Built-in fixed pattern + GatherMask(groupIndex, sortedTopValue.template ReinterpretCast(), src1Pattern, false, + static_cast(0), + {1, static_cast(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt); + PipeBarrier(); + duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + duplicateIndex = kGroup_ - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + PipeBarrier(); + Duplicate(groupIndex.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS); + } + + // Sort the selected group indices in descending order + LocalTensor sortedGroupIndex = sortedGroupIndexBuf_.Get(); + PipeBarrier(); + Sort(sortedGroupIndex, groupIndex.ReinterpretCast(), groupIndex, sortTmp, Ceil(kGroup_, 32)); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::SelectTopKExpertIdx() +{ + LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); + 
LocalTensor sortedGroupIndex = sortedGroupIndexBuf_.Get(); + LocalTensor topKExpertId = topKExpertIdBuf_.Get(); + LocalTensor mrgSort0Tensor = calcTmpBuf_.Get(); + + uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0}; + uint16_t lenArr[CONSTANT_FOUR] = { + static_cast(perGroupExpertCount_), static_cast(perGroupExpertCount_), + static_cast(perGroupExpertCount_), static_cast(perGroupExpertCount_)}; + MrgSort4Info params{lenArr, false, 0b1111, 1}; + MrgSortSrcList srcList; + + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + + for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) { + int64_t mrgLen = Min(i + 1, CONSTANT_FOUR); + if (mrgLen > 1) { + if (mrgLen == MERGE_LIST_FOUR) { + offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; + offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; + offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2; + offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2; + } else if (mrgLen == MERGE_LIST_THREE) { + offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; + offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; + offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2; + offset[3] = 0; + params.elementLengths[3] = 0; + params.validBit = 0b111; + } else { + offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; + offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; + offset[2] = 0; + offset[3] = 0; + params.elementLengths[2] = 0; + params.elementLengths[3] = 0; + params.validBit = 0b11; + } + + srcList.src1 = sortedInGroupTensor[offset[0]]; + srcList.src2 = sortedInGroupTensor[offset[1]]; + srcList.src3 = sortedInGroupTensor[offset[2]]; + srcList.src4 = sortedInGroupTensor[offset[3]]; + + PipeBarrier(); 
+ MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params); + } else { + offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; + PipeBarrier(); + DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]], + perGroupExpertCountAlign_ * 2); + } + } + int32_t baseLoop = 4; + LocalTensor srcTensor = mrgSort0Tensor; + LocalTensor dstTensor = mrgSort0Tensor; + for (int i = 0; i < tilingData_->vmsCount; i++) { + if (i % 2 == 0) { + srcTensor = mrgSort0Tensor; + dstTensor = sortedInGroupTensor; + } else { + srcTensor = sortedInGroupTensor; + dstTensor = mrgSort0Tensor; + } + + int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR; + int32_t quotient = kGroup_ / nextBaseRow; + int32_t remainder = kGroup_ - quotient * nextBaseRow; + if (quotient > 0) { + MrgSort4Info params; + MrgSortSrcList srcList; + params.ifExhaustedSuspension = false; + params.elementLengths[0] = perGroupExpertCount_ * baseLoop; + params.elementLengths[1] = perGroupExpertCount_ * baseLoop; + params.elementLengths[2] = perGroupExpertCount_ * baseLoop; + params.elementLengths[3] = perGroupExpertCount_ * baseLoop; + params.validBit = 0b1111; + params.repeatTimes = 1; + for (int j = 0; j < quotient; j++) { + srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j]; + srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)]; + srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)]; + srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)]; + PipeBarrier(); + MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params); + } + } + + if (remainder > 0) { + int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2; + int32_t mrgLen = CeilDiv(remainder, baseLoop); + int32_t tailRow = remainder - (mrgLen - 1) * baseLoop; + if (mrgLen > 1) { + MrgSort4Info params; + MrgSortSrcList srcList; + 
params.repeatTimes = 1; + params.ifExhaustedSuspension = false; + params.elementLengths[0] = perGroupExpertCount_ * baseLoop; + params.elementLengths[1] = perGroupExpertCount_ * baseLoop; + params.elementLengths[2] = perGroupExpertCount_ * baseLoop; + params.elementLengths[3] = perGroupExpertCount_ * baseLoop; + srcList.src1 = srcTensor[baseOffset]; + srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2]; + if (mrgLen == MERGE_LIST_FOUR) { + srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2]; + srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3]; + params.elementLengths[3] = perGroupExpertCount_ * tailRow; + params.validBit = 0b1111; + } else if (mrgLen == MERGE_LIST_THREE) { + srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2]; + params.elementLengths[2] = perGroupExpertCount_ * tailRow; + params.elementLengths[3] = 0; + params.validBit = 0b111; + } else { + params.elementLengths[1] = perGroupExpertCount_ * tailRow; + params.elementLengths[2] = 0; + params.elementLengths[3] = 0; + params.validBit = 0b11; + } + PipeBarrier(); + MrgSort(dstTensor[baseOffset], srcList, params); + } else { + PipeBarrier(); + DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2); + } + } + baseLoop = nextBaseRow; + } + + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES); + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS; + gatherMaskParams.src1RepeatStride = 0; + + uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering + uint8_t src1Pattern = 2; // Built-in fixed pattern + PipeBarrier(); + GatherMask(topKExpertId, dstTensor.template ReinterpretCast(), src1Pattern, false, + static_cast(0), gatherMaskParams, rsvdCnt); +} + +template +__aicore__ inline void 
MoeGatingTopKGenerlized::SelectTopKExpertScore() +{ + LocalTensor xNormTensor = xNormBuf_.Get(); + LocalTensor yOutTensor = yOutQueue_.AllocTensor(); + LocalTensor topKExpertId = topKExpertIdBuf_.Get(); + LocalTensor topKExpertIdWithByte = calcTmpBuf_.Get(); + PipeBarrier(); + Muls(topKExpertIdWithByte, topKExpertId, static_cast(sizeof(float)), k_); + PipeBarrier(); + Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast(), static_cast(0), + k_); + bool needRenorm = (normType_ == 1 ) || // Case 1: sigmoid + renorm + (normType_ == 0 && renorm_ == 1); // Case 3: softmax + renorm + if (needRenorm) { + LocalTensor maxValueTensor = calcTmpBuf_.Get(); + LocalTensor tmpTensor = calcTmpBuf_.Get()[32]; + PipeBarrier(); + ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps; + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Duplicate(tmpTensor, sumValue, k_); + PipeBarrier(); + Div(yOutTensor, yOutTensor, tmpTensor, k_); + } + PipeBarrier(); + Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_); + + if constexpr (!IsSameType::value) { + PipeBarrier(); + Cast(yOutTensor.ReinterpretCast(), yOutTensor, RoundMode::CAST_RINT, k_); + } + + yOutQueue_.EnQue(yOutTensor); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::CumputeActualTopKExpertId() +{ + LocalTensor expertIdxOut = expertIdxOutQueue_.AllocTensor(); + LocalTensor topKExpertId = topKExpertIdBuf_.Get(); + LocalTensor topKExpertIdFp32 = calcTmpBuf_.Get(); + + PipeBarrier(); + Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_); + PipeBarrier(); + Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_); + PipeBarrier(); + Cast(expertIdxOut, 
topKExpertIdFp32, RoundMode::CAST_TRUNC, k_); + PipeBarrier(); + Muls(expertIdxOut, expertIdxOut, static_cast(perGroupExpertCountAlign_ - perGroupExpertCount_), k_); + PipeBarrier(); + Sub(expertIdxOut, topKExpertId, expertIdxOut, k_); + expertIdxOutQueue_.EnQue(expertIdxOut); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::CopyOut(int64_t row) +{ + LocalTensor yOutTensor = yOutQueue_.DeQue(); + LocalTensor expertIdxOut = expertIdxOutQueue_.DeQue(); + DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; + DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); + dataCopyParams.blockLen = k_ * sizeof(int32_t); + DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams); + yOutQueue_.FreeTensor(yOutTensor); + expertIdxOutQueue_.FreeTensor(expertIdxOut); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, + GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) +{ + tilingData_ = tilingData; + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + perCoreRowCount_ = tilingData_->perCoreRowCount; + if (blockIdx_ == GetBlockNum() - 1) { + curCoreRowCount_ = tilingData_->lastCoreRowCount; + } else { + curCoreRowCount_ = tilingData_->perCoreRowCount; + } + expertCount_ = tilingData_->expertCount; + addBias_ = tilingData_->addBias == 1; + k_ = tilingData_->k; + kGroup_ = tilingData_->kGroup; + groupCount_ = tilingData_->groupCount; + groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + perGroupExpertCount_ = tilingData_->perGroupExpertCount; + perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign; + renorm_ = tilingData_->renorm; + normType_ = tilingData_->normType; + groupSelectMode_ = tilingData_->groupSelectMode; + + expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float)); + kAlign_ = Align(k_, sizeof(float)); + + isAlign_ = perGroupExpertCount_ == 
perGroupExpertCountAlign_; + + // init input gm buf + xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); + biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); + + // init output gm buf + yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); + outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); + + // init que + pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float)); + pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t)); + pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float)); + + pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t)); + + pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float)); + + pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float)); + pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t))); + + pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO); + pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t)); + pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10); +} + +template +__aicore__ inline void MoeGatingTopKGenerlized::Process() +{ + CopyInBiasAndInitExpertId(); + for (int64_t row = 0; row < curCoreRowCount_; row++) { + CopyInX(row); + ComputeX(); + if (tilingData_->outFlag) { + CopuOutXNorm(row); + } + SortInGroup(); + SelectTopKGroupIndex(); + SelectTopKExpertIdx(); + SelectTopKExpertScore(); + CumputeActualTopKExpertId(); + CopyOut(row); + } +} +} // namespace MoeGatingTopK +#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H diff --git 
a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h new file mode 100644 index 00000000..a28ea3da --- /dev/null +++ b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h @@ -0,0 +1,338 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_gating_top_k_without_group.h + * \brief + */ +#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H +#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H +#include "kernel_operator.h" +#include "common.h" +#include "kernel_utils.h" +namespace MoeGatingTopK { +using namespace AscendC; + +template +class MoeGatingTopKWithoutGroup { +public: + __aicore__ inline MoeGatingTopKWithoutGroup(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyInBiasAndInitExpertId(); + __aicore__ inline void CopyInX(int64_t progress); + __aicore__ inline void ComputeX(); + __aicore__ inline void CopuOutXNorm(int64_t row); + __aicore__ inline void SelectTopKExpertIdx(); + __aicore__ inline void SelectTopKExpertScore(); + __aicore__ inline void CopyOut(int64_t row); + +private: + TPipe *pipe_; + TQue xInQueue_; + TQue yOutQueue_; + TQue 
expertIdxOutQueue_; + TQue outOutQueue_; + + TBuf biasBuf_; // Store input bias + TBuf expertIdBuf_; // Expert ID + TBuf xNormWithBiasBuf_; // Store value after adding bias + TBuf xNormBuf_; // Store value after computing sigmoid or softmax + TBuf topKExpertIdBuf_; + TBuf calcTmpBuf_; + + GlobalTensor xGm_; + GlobalTensor biasGm_; + GlobalTensor yGm_; + GlobalTensor expertIdxGm_; + GlobalTensor outGm_; + + int64_t blockIdx_ = 0; + int64_t perCoreRowCount_ = 0; + int64_t curCoreRowCount_ = 0; + int64_t expertCount_ = 0; + bool addBias_ = false; + bool outFlag_ = false; + int64_t k_ = 0; + int64_t renorm_ = 0; + int64_t normType_ = 0; + int64_t expertCountAlign_ = 0; + const MoeGatingTopKTilingData *tilingData_; +}; + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::CopyInBiasAndInitExpertId() +{ + LocalTensor biasTensor = biasBuf_.Get(); + LocalTensor expertIdTensor = expertIdBuf_.Get(); + DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if (addBias_) { + if constexpr (IsSameType::value) { + DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + } else { + DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, + expertCountAlign_); + PipeBarrier(); + } + } + ArithProgression(expertIdTensor, static_cast(0), static_cast(1), expertCount_); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::CopyInX(int64_t row) +{ + LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); + DataCopyExtParams 
dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; + if constexpr (IsSameType::value) { + DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); + } else { + DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], dataCopyParams, + dataCopyPadParams); + } + xInQueue_.EnQue(xInLocalTensor); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::ComputeX() +{ + LocalTensor xNormTensor = xNormBuf_.Get(); + LocalTensor xInLocalTensor = xInQueue_.DeQue(); + LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); + LocalTensor biasTensor = biasBuf_.Get(); + + if constexpr (!IsSameType::value) { + Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, + expertCount_); + PipeBarrier(); + } + + if (normType_ == 1) { // sigmoid + LocalTensor calcNormTmpTensor = calcTmpBuf_.Get(); + Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_); + PipeBarrier(); + } else if (normType_ == 0) { // sigmoid + LocalTensor reduceValueTensor = calcTmpBuf_.Get(); + LocalTensor calcTmp = calcTmpBuf_.Get()[8]; + ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float maxValue = reduceValueTensor.GetValue(0); + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_); + PipeBarrier(); + Exp(xNormTensor, xNormTensor, expertCount_); + PipeBarrier(); + ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_); + eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float sumValue = reduceValueTensor.GetValue(0); + 
eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_); + PipeBarrier(); + } + if (addBias_) { + Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_); + } else { + DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_); + } + + int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM; + int duplicateIndex = expertCount_ - duplicateNum; + if (duplicateNum > 0) { + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(xNormWithBiasTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1); + PipeBarrier(); + } + xInQueue_.FreeTensor(xInLocalTensor); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::CopuOutXNorm(int64_t row) +{ + LocalTensor outOutTensor = outOutQueue_.AllocTensor(); + LocalTensor xNormTensor = xNormBuf_.Get(); + DataCopy(outOutTensor, xNormTensor, expertCountAlign_); + outOutQueue_.EnQue(outOutTensor); + outOutTensor = outOutQueue_.DeQue(); + DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(float)), 0, 0, 0}; + DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams); + outOutQueue_.FreeTensor(outOutTensor); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::SelectTopKExpertIdx() +{ + LocalTensor expertIdxOut = expertIdxOutQueue_.AllocTensor(); + LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); + LocalTensor expertIdTensor = expertIdBuf_.Get(); + LocalTensor topKExpertId = topKExpertIdBuf_.Get(); + LocalTensor sortedScore = calcTmpBuf_.Get(); + LocalTensor sortTmp = calcTmpBuf_.Get()[expertCountAlign_ * CONSTANT_TWO]; + PipeBarrier(); + Sort(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp, + expertCountAlign_ / ONE_REPEAT_SORT_NUM); + + GatherMaskParams gatherMaskParams; + 
gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES); + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS; + gatherMaskParams.src1RepeatStride = 0; + + uint64_t rsvdCnt = 0; // Used to store the number of elements retained after filtering + uint8_t src1Pattern = 2; // Built-in fixed pattern + PipeBarrier(); + GatherMask(topKExpertId, sortedScore.template ReinterpretCast(), src1Pattern, false, + static_cast(0), gatherMaskParams, rsvdCnt); + + DataCopy(expertIdxOut, topKExpertId, expertCountAlign_); + expertIdxOutQueue_.EnQue(expertIdxOut); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::SelectTopKExpertScore() +{ + LocalTensor xNormTensor = xNormBuf_.Get(); + LocalTensor yOutTensor = yOutQueue_.AllocTensor(); + LocalTensor topKExpertId = topKExpertIdBuf_.Get(); + LocalTensor topKExpertIdWithByte = calcTmpBuf_.Get(); + PipeBarrier(); + Muls(topKExpertIdWithByte, topKExpertId, static_cast(sizeof(float)), k_); + PipeBarrier(); + Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast(), static_cast(0), + k_); + + bool needRenorm = (normType_ == 1 ) || // Case 1: sigmoid + renorm + (normType_ == 0 && renorm_ == 1); // Case 3: softmax + renorm + if (needRenorm == 1) { + LocalTensor maxValueTensor = calcTmpBuf_.Get(); + LocalTensor tmpTensor = calcTmpBuf_.Get()[BLOCK_BYTES]; + PipeBarrier(); + ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_); + event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); + float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps; + event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + SetFlag(eventIdSToV); + WaitFlag(eventIdSToV); + Duplicate(tmpTensor, sumValue, k_); + PipeBarrier(); + Div(yOutTensor, yOutTensor, tmpTensor, k_); + } + PipeBarrier(); + Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_); + + 
if constexpr (!IsSameType::value) { + PipeBarrier(); + Cast(yOutTensor.ReinterpretCast(), yOutTensor, RoundMode::CAST_RINT, k_); + } + + yOutQueue_.EnQue(yOutTensor); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::CopyOut(int64_t row) +{ + LocalTensor yOutTensor = yOutQueue_.DeQue(); + LocalTensor expertIdxOut = expertIdxOutQueue_.DeQue(); + DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; + DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); + dataCopyParams.blockLen = k_ * sizeof(int32_t); + DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams); + yOutQueue_.FreeTensor(yOutTensor); + expertIdxOutQueue_.FreeTensor(expertIdxOut); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, + GM_ADDR out, GM_ADDR workspace, + const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) +{ + tilingData_ = tilingData; + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + perCoreRowCount_ = tilingData_->perCoreRowCount; + if (blockIdx_ == GetBlockNum() - 1) { + curCoreRowCount_ = tilingData_->lastCoreRowCount; + } else { + curCoreRowCount_ = tilingData_->perCoreRowCount; + } + expertCount_ = tilingData_->expertCount; + addBias_ = tilingData_->addBias == 1; + outFlag_ = tilingData_->outFlag == 1; + k_ = tilingData_->k; + renorm_ = tilingData_->renorm; + normType_ = tilingData_->normType; + expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + + // init input gm buf + xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); + biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); + + // init output gm buf + yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); + outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, 
expertCount_); + + // init que + pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float)); + pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t)); + pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float)); + + // init calc buf + pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); + pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t)); + pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float)); + pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float)); + pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t)); + + // init tmp buf + pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT); +} + +template +__aicore__ inline void MoeGatingTopKWithoutGroup::Process() +{ + CopyInBiasAndInitExpertId(); + for (int64_t row = 0; row < curCoreRowCount_; row++) { + CopyInX(row); + ComputeX(); + if (outFlag_) { + CopuOutXNorm(row); + } + SelectTopKExpertIdx(); + SelectTopKExpertScore(); + CopyOut(row); + } +} +} // namespace MoeGatingTopK +#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h new file mode 100644 index 00000000..f7e7fce1 --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h @@ -0,0 +1,51 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file data_copy_transpose_tiling.h + * \brief + */ + +#pragma once + +#include +#include +#include "data_copy_transpose_tiling_def.h" + +namespace optiling { + +inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize, + optiling::CopyTransposeTiling &tiling) +{ + constexpr int64_t B_INDEX = 0; + constexpr int64_t N_INDEX = 1; + constexpr int64_t S_INDEX = 2; + constexpr int64_t H_INDEX = 3; + std::vector dstShapeInfo = dstShape.GetDims(); + std::vector srcShapeInfo = srcShape.GetDims(); + + tiling.set_dstShapeB(dstShapeInfo[B_INDEX]); + tiling.set_dstShapeN(dstShapeInfo[N_INDEX]); + tiling.set_dstShapeS(dstShapeInfo[S_INDEX]); + tiling.set_dstShapeH(dstShapeInfo[H_INDEX]); + tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN()); + + tiling.set_srcShapeB(srcShapeInfo[B_INDEX]); + tiling.set_srcShapeN(srcShapeInfo[N_INDEX]); + tiling.set_srcShapeS(srcShapeInfo[S_INDEX]); + tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]); + tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize); + tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH()); + tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS()); + tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN()); + tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH()); +} + +} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h new file mode 100644 index 00000000..18552c36 --- /dev/null +++ 
b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h @@ -0,0 +1,43 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file data_copy_transpose_tiling_def.h + * \brief + */ + +#pragma once + +#include +#include + +namespace optiling { + +BEGIN_TILING_DATA_DEF(CopyTransposeTiling) +TILING_DATA_FIELD_DEF(uint32_t, dstShapeB); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeS); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeH); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeB); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeN); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeS); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen); +TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue); +TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling); +TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue); +TILING_DATA_FIELD_DEF(uint32_t, paramsAlign); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling) + +} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/error_log.h b/csrc/moe_gating_top_k/tiling_base/error_log.h new file mode 100644 index 00000000..770bbfe8 --- /dev/null 
+++ b/csrc/moe_gating_top_k/tiling_base/error_log.h @@ -0,0 +1,56 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +// Modify OP_TILING_CHECK macro to ensure proper handling of expressions +#define OP_CHECK_IF(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) + + + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if ((ptr) == nullptr) { \ + OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_base.h b/csrc/moe_gating_top_k/tiling_base/tiling_base.h new file mode 100644 index 00000000..f0bbbdcc --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/tiling_base.h @@ -0,0 +1,256 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_base.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include "tiling/platform/platform_ascendc.h" +#include "error_log.h" + +#ifdef ASCENDC_OP_TEST +#define ASCENDC_EXTERN_C extern "C" +#else +#define ASCENDC_EXTERN_C +#endif + +namespace Ops { +namespace Transformer { +namespace OpTiling { + +struct AiCoreParams { + uint64_t ubSize = 0; + uint64_t blockDim = 0; + uint64_t aicNum = 0; + uint64_t l1Size = 0; + uint64_t l0aSize = 0; + uint64_t l0bSize = 0; + uint64_t l0cSize = 0; +}; + +struct CompileInfoCommon { + uint32_t aivNum; + uint32_t aicNum; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + uint64_t l2CacheSize; + int64_t coreNum; + int32_t socVersion; + uint32_t rsvd; +}; + +struct FlashAttentionScoreGradCompileInfo { + uint32_t aivNum; + uint32_t aicNum; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + uint64_t l2CacheSize; + int64_t coreNum; + platform_ascendc::SocVersion socVersion; +}; + +struct FACompileInfoCommon { + uint32_t aivNum; + uint32_t aicNum; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + uint64_t l2CacheSize; + int64_t coreNum; + int32_t socVersion; + uint32_t rsvd; +}; + +class TilingBaseClass { +public: + explicit TilingBaseClass(gert::TilingContext* context) : context_(context) + {} + + virtual ~TilingBaseClass() = default; + + // Tiling execution framework + // 1. GRAPH_SUCCESS: Success, and no need to continue executing subsequent Tiling class implementations + // 2. GRAPH_FAILED: Failure, abort the entire Tiling process + // 3. 
GRAPH_PARAM_INVALID: This class does not support, need to continue executing other Tiling class implementations + ge::graphStatus DoTiling() + { + auto ret = GetShapeAttrsInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetPlatformInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + if (!IsCapable()) { + return ge::GRAPH_PARAM_INVALID; + } + ret = DoOpTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = DoLibApiTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetWorkspaceSize(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = PostTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + context_->SetTilingKey(GetTilingKey()); + DumpTilingInfo(); + return ge::GRAPH_SUCCESS; + } + + // Update context + virtual void Reset(gert::TilingContext* context) + { + context_ = context; + } + +protected: + virtual bool IsCapable() = 0; + // 1. Get platform information such as CoreNum, UB/L1/L0C resource sizes + virtual ge::graphStatus GetPlatformInfo() = 0; + // 2. Get INPUT/OUTPUT/ATTR information + virtual ge::graphStatus GetShapeAttrsInfo() = 0; + // 3. Calculate data splitting TilingData + virtual ge::graphStatus DoOpTiling() = 0; + // 4. Calculate high-level API TilingData + virtual ge::graphStatus DoLibApiTiling() = 0; + // 5. Calculate TilingKey + [[nodiscard]] virtual uint64_t GetTilingKey() const = 0; + // 6. Calculate Workspace size + virtual ge::graphStatus GetWorkspaceSize() = 0; + // 7. Save Tiling data + virtual ge::graphStatus PostTiling() = 0; + // 8. Dump Tiling data + virtual void DumpTilingInfo() + { + int32_t enable = CheckLogLevel(static_cast(OP), DLOG_DEBUG); + if (enable != 1) { + return; + } + auto buf = (uint32_t*)context_->GetRawTilingData()->GetData(); + auto bufLen = context_->GetRawTilingData()->GetDataSize(); + std::ostringstream oss; + oss << "Start to dump tiling info. 
tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen + << ", content:"; + for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) { + oss << *(buf + i) << ","; + if (oss.str().length() > 640) { // Split according to 640 to avoid truncation + OP_LOGD(context_, "%s", oss.str().c_str()); + oss.str(""); + } + } + OP_LOGD(context_, "%s", oss.str().c_str()); + } + + static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) + { + uint32_t ration; + if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) { + return sliceNum; + } + ration = aivCoreNum / aicCoreNum; + return (sliceNum + (ration - 1)) / ration; + } + + template + [[nodiscard]] std::string GetShapeDebugStr(const T& shape) const + { + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); + } + + [[nodiscard]] std::string GetTensorDebugStr( + const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor) + { + if (shape == nullptr || tensor == nullptr) { + return "nil "; + } + std::ostringstream oss; + oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),"; + oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),"; + oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),"; + oss << "(format: " + << ge::TypeUtils::FormatToSerialString( + static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) + << "),"; + oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") "; + return oss.str(); + } + + [[nodiscard]] std::string GetTilingContextDebugStr() + { + std::ostringstream oss; + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) { + oss << "input" << i << ": "; + oss << 
GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i)); + } + + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) { + oss << "output" << i << ": "; + oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i)); + } + return oss.str(); + } + + [[nodiscard]] std::string GetTilingDataDebugStr() const + { + auto rawTilingData = context_->GetRawTilingData(); + auto rawTilingDataSize = rawTilingData->GetDataSize(); + auto data = reinterpret_cast(rawTilingData->GetData()); + size_t len = rawTilingDataSize / sizeof(int32_t); + std::ostringstream oss; + for (size_t i = 0; i < len; i++) { + oss << data[i] << ", "; + } + return oss.str(); + } + +protected: + gert::TilingContext* context_ = nullptr; + std::unique_ptr ascendcPlatform_{nullptr}; + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + AiCoreParams aicoreParams_; +}; + +} // namespace OpTiling +} // namespace Transformer +} // namespace Ops \ No newline at end of file diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_key.h b/csrc/moe_gating_top_k/tiling_base/tiling_key.h new file mode 100644 index 00000000..ddc105cf --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/tiling_key.h @@ -0,0 +1,63 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file tiling_key.h + * \brief + */ + +#pragma once + +#include + +namespace Ops { +namespace Transformer { +namespace OpTiling { +constexpr uint64_t RecursiveSum() +{ + return 0; +} + +constexpr uint64_t kBase = 10; // Base-10 carry base +template constexpr uint64_t RecursiveSum(T templateId, Args... templateIds) +{ + return static_cast(templateId) + kBase * RecursiveSum(templateIds...); +} + +// TilingKey generation rules: +// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1, +// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1: +// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting, +// fill with AXIS_NONE. UB0 and UB1 each occupy one decimal digit; +// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit; +// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit +// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit +// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit +// For other specialized scenarios, define your own bit fields and values +// usage: get tilingKey from inputed types +// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, +// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL) + +constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19 +template constexpr uint64_t GET_TILINGKEY(Args... 
templateIds) +{ + return TILINGKEYOFFSET + RecursiveSum(templateIds...); +} + +// usage: get tilingKey from inputed types +// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL) + +#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \ + (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \ + SparseEnum::sparse)) + +} // namespace Optiling +} // namespace Transformer +} // namespace Ops diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h b/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h new file mode 100644 index 00000000..cbf4785a --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h @@ -0,0 +1,351 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file tiling_templates_registry.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include "exe_graph/runtime/tiling_context.h" +#include "tiling_base.h" +#include "error_log.h" + +namespace Ops { +namespace Transformer { +namespace OpTiling { + +template +std::unique_ptr TILING_CLASS(gert::TilingContext* context) +{ + return std::unique_ptr(new (std::nothrow) T(context)); +} + +using TilingClassCase = std::unique_ptr (*)(gert::TilingContext*); + +class TilingCases { +public: + explicit TilingCases(std::string op_type) : op_type_(std::move(op_type)) + {} + + template + void AddTiling(int32_t priority) + { + OP_CHECK_IF( + cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return); + cases_[priority] = TILING_CLASS; + OP_CHECK_IF( + cases_[priority] == nullptr, + OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return); + } + + const std::map& GetTilingCases() + { + return cases_; + } + +private: + std::map cases_; + const std::string op_type_; +}; + +// --------------------------------Interfacce with soc version -------------------------------- +class TilingRegistryNew { +public: + TilingRegistryNew() = default; + +#ifdef ASCENDC_OP_TEST + static TilingRegistryNew& GetInstance(); +#else + static TilingRegistryNew& GetInstance() + { + static TilingRegistryNew registry_impl_; + return registry_impl_; + } +#endif + + std::shared_ptr RegisterOp(const std::string& op_type, int32_t soc_version) + { + auto soc_iter = registry_map_.find(soc_version); + if (soc_iter == registry_map_.end()) { + std::map> op_type_map; + op_type_map[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); + registry_map_[soc_version] = op_type_map; + } else { + if (soc_iter->second.find(op_type) == soc_iter->second.end()) { + soc_iter->second[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); + } + } + + OP_CHECK_IF( + registry_map_[soc_version][op_type] == nullptr, 
+ OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); + return registry_map_[soc_version][op_type]; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext* context) + { + int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION; + const char* op_type = context->GetNodeType(); + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + if (platformInfoPtr == nullptr) { + auto compileInfoPtr = static_cast(context->GetCompileInfo()); + OP_CHECK_IF( + compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); + soc_version = compileInfoPtr->socVersion; + OP_LOGD(context, "soc version in compileInfo is %d", soc_version); + } else { + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + soc_version = static_cast(ascendcPlatform.GetSocVersion()); + OP_LOGD(context, "soc version is %d", soc_version); + if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) { + OP_LOGE(op_type, "Do op tiling failed, cannot find soc version."); + return ge::GRAPH_FAILED; + } + } + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); + for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { + auto tilingTemplate = it->second(context); + if (tilingTemplate != nullptr) { + ge::graphStatus status = tilingTemplate->DoTiling(); + if (status != ge::GRAPH_PARAM_INVALID) { + OP_LOGD(context, "Do general op tiling success priority=%d", it->first); + return status; + } + OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); + } + } + OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); + return ge::GRAPH_FAILED; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) + { + int32_t soc_version; + const char* op_type = context->GetNodeType(); + auto platformInfoPtr = context->GetPlatformInfo(); + if 
(platformInfoPtr == nullptr) { + auto compileInfoPtr = reinterpret_cast(context->GetCompileInfo()); + OP_CHECK_IF( + compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); + soc_version = compileInfoPtr->socVersion; + OP_LOGD(context, "soc version in compileInfo is %d", soc_version); + } else { + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + soc_version = static_cast(ascendcPlatform.GetSocVersion()); + OP_LOGD(context, "soc version is %d", soc_version); + } + + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); + for (auto priority_id : priorities) { + auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id); + if (tilingCaseIter != tilingTemplateRegistryMap.end()) { + auto templateFunc = tilingCaseIter->second(context); + if (templateFunc != nullptr) { + ge::graphStatus status = templateFunc->DoTiling(); + if (status == ge::GRAPH_SUCCESS) { + OP_LOGD(context, "Do general op tiling success priority=%d", priority_id); + return status; + } + OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id); + } + } + } + return ge::GRAPH_FAILED; + } + + const std::map& GetTilingTemplates(const std::string& op_type, int32_t soc_version) + { + auto soc_iter = registry_map_.find(soc_version); + OP_CHECK_IF( + soc_iter == registry_map_.end(), + OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version), + return empty_tiling_case_); + auto op_iter = soc_iter->second.find(op_type); + OP_CHECK_IF( + op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), + return empty_tiling_case_); + return op_iter->second->GetTilingCases(); + } + +private: + std::map>> registry_map_; // key is socversion + const std::map empty_tiling_case_{}; +}; + +class RegisterNew { +public: + explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type)) + {} + + template + RegisterNew& 
tiling(int32_t priority, int32_t soc_version) + { + auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); + OP_CHECK_IF( + tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); + tilingCases->AddTiling(priority); + return *this; + } + + template + RegisterNew& tiling(int32_t priority, const std::vector& soc_versions) + { + for (int32_t soc_version : soc_versions) { + auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); + OP_CHECK_IF( + tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), + return *this); + tilingCases->AddTiling(priority); + } + return *this; + } + +private: + const std::string op_type_; +}; + +// --------------------------------Interfacce without soc version -------------------------------- +class TilingRegistry { +public: + TilingRegistry() = default; + +#ifdef ASCENDC_OP_TEST + static TilingRegistry& GetInstance(); +#else + static TilingRegistry& GetInstance() + { + static TilingRegistry registry_impl_; + return registry_impl_; + } +#endif + + std::shared_ptr RegisterOp(const std::string& op_type) + { + if (registry_map_.find(op_type) == registry_map_.end()) { + registry_map_[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); + } + OP_CHECK_IF( + registry_map_[op_type] == nullptr, + OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); + return registry_map_[op_type]; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext* context) + { + const char* op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { + auto tilingTemplate = it->second(context); + if (tilingTemplate != nullptr) { + ge::graphStatus status = tilingTemplate->DoTiling(); + if (status != ge::GRAPH_PARAM_INVALID) { + 
OP_LOGD(context, "Do general op tiling success priority=%d", it->first); + return status; + } + OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); + } + } + OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); + return ge::GRAPH_FAILED; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) + { + const char* op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto priorityId : priorities) { + auto templateFunc = tilingTemplateRegistryMap[priorityId](context); + if (templateFunc != nullptr) { + ge::graphStatus status = templateFunc->DoTiling(); + if (status == ge::GRAPH_SUCCESS) { + OP_LOGD(context, "Do general op tiling success priority=%d", priorityId); + return status; + } + if (status != ge::GRAPH_PARAM_INVALID) { + OP_LOGD(context, "Do op tiling failed"); + return status; + } + OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId); + } + } + OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); + return ge::GRAPH_FAILED; + } + + const std::map& GetTilingTemplates(const std::string& op_type) + { + OP_CHECK_IF( + registry_map_.find(op_type) == registry_map_.end(), + OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_); + return registry_map_[op_type]->GetTilingCases(); + } + +private: + std::map> registry_map_; + const std::map empty_tiling_case_; +}; + +class Register { +public: + explicit Register(std::string op_type) : op_type_(std::move(op_type)) + {} + + template + Register& tiling(int32_t priority) + { + auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_); + OP_CHECK_IF( + tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); + tilingCases->AddTiling(priority); + return *this; + } + +private: + const std::string op_type_; +}; +} // namespace OpTiling +} // namespace Transformer +} 
// namespace Ops + +// op_type: operator name, class_name: registered tiling class, soc_version: chip version number +// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first +#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \ + [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ + static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ + Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_versions) + +// op_type: operator name, class_name: registered tiling class +// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected +#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \ + [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ + static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \ + Ops::Transformer::OpTiling::Register(op_type).tiling(priority) + +// op_type: operator name, class_name: registered tiling class +// soc_version: SOC version, used to distinguish different SOCs +// priority: priority of tiling class, smaller value means higher priority, i.e., this tiling class will be selected first +#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \ + [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ + static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ + Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_version) + +// op_type: operator name, class_name: registered tiling class +// priority: priority of tiling class, smaller value means higher priority, i.e., higher probability of being selected +// Replaces REGISTER_TILING_TEMPLATE, if op_type is a string 
constant, remove the quotes +#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \ + [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ + static Ops::Transformer::OpTiling::Register \ + __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \ + Ops::Transformer::OpTiling::Register(#op_type).tiling(priority) diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_type.h b/csrc/moe_gating_top_k/tiling_base/tiling_type.h new file mode 100644 index 00000000..1f4b29c7 --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/tiling_type.h @@ -0,0 +1,139 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file tiling_type.h + * \brief + */ + +#pragma once + +#include + +namespace optiling { + +enum class AxisEnum { + B = 0, + N2 = 1, + G = 2, + S1 = 3, + S2 = 4, + D = 5, + NONE = 9, +}; + +enum class DtypeEnum { + FLOAT16 = 0, + FLOAT32 = 1, + BFLOAT16 = 2, + FLOAT16_PRECISION = 3, +}; + +enum class PerformanceOrientedEnum { + BIG_BUFFER = 1, + BIG_DOUBLE_BUFFER = 2, +}; + +enum class MatmulConfig { + NULL_CONFIG = 0, + NORMAL_CONFIG = 1, + MDL_CONFIG = 2 +}; + +enum class PseConfig { + NO_PSE = 0, + EXIST_PSE = 1 +}; + +enum class AttenMaskConfig { + NO_ATTEN_MASK = 0, + EXIST_ATTEN_MASK = 1 +}; + +enum class DropOutConfig { + NO_DROP_OUT = 0, + EXIST_DROP_OUT = 1 +}; + +enum class CubeFormatEnum { + ND = 0, + NZ = 1 +}; +enum class LayoutEnum { + BSND = 0, + SBND = 1, + BNSD = 2, + TND = 3, + NTD_TND = 4 +}; + +enum class CubeInputSourceEnum { + GM = 0, + L1 = 1 +}; + +enum class OptionEnum { + DISABLE = 0, + ENABLE = 1 +}; + +enum class SparseEnum { + ALL = 0, + NONE = 1, + ANY = 2, + CAUSAL = 3, + BAND = 4, + PREFIX = 5, + BAND_COMPRESS = 6, + RIGHT_DOWN_CAUSAL = 7, + RIGHT_DOWN_CAUSAL_BAND = 8, + BAND_LEFT_UP_CAUSAL = 9 +}; + +constexpr uint64_t RecursiveSum() +{ + return 0; +} + +constexpr int64_t base10Multiplier = 10; + +template constexpr uint64_t RecursiveSum(T templateId, Args... templateIds) +{ + return static_cast(templateId) + base10Multiplier * RecursiveSum(templateIds...); +} + +// TilingKey generation rules: +// FlashAttentionScore/FlashAttentionScoreGrad assembles tiling key using decimal digits, containing the following key parameters from low to high: Ub0, Ub1, +// Block, DataType, Format, Sparse. Specialized template Ub0, Ub1: +// Represents the axis for UB intra-core splitting, using AxisEnum. Since we allow at most two axes to be split, UB0 and UB1 exist. If there is no UB intra-core splitting, +// fill with AXIS_NONE. 
UB0 and UB1 each occupy one decimal digit; +// Block: Represents the axis used by UB for multi-core splitting, using AxisEnum, occupies one decimal digit; +// DataType: Represents the input/output data types supported by the current tiling key, using SupportedDtype enum, occupies one decimal digit +// Format: Represents the Format supported by the current tiling key, using InputLayout enum, occupies one decimal digit +// Sparse: Represents whether the current tiling key supports Sparse, using SparseCapability enum, occupies one decimal digit +// For other specialized scenarios, define your own bit fields and values +// usage: get tilingKey from inputed types +// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, +// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL) + +constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19 +template constexpr uint64_t GET_TILINGKEY(Args... templateIds) +{ + return TILINGKEYOFFSET + RecursiveSum(templateIds...); +} + +// usage: get tilingKey from inputed types +// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL) + +#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \ + (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \ + SparseEnum::sparse)) + +} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_util.h b/csrc/moe_gating_top_k/tiling_base/tiling_util.h new file mode 100644 index 00000000..fb6ffa2d --- /dev/null +++ b/csrc/moe_gating_top_k/tiling_base/tiling_util.h @@ -0,0 +1,30 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_util.h + * \brief + */ + +#pragma once + +#include "register/op_impl_registry.h" + +namespace Ops { +namespace Transformer { +namespace OpTiling { +bool IsRegbaseSocVersion(const gert::TilingParseContext* context); + +bool IsRegbaseSocVersion(const gert::TilingContext* context); + +const gert::Shape& EnsureNotScalar(const gert::Shape& inShape); +} // namespace OpTiling +} // namespace Transformer +} // namespace Ops \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 5ff705aa..ca701235 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1219,10 +1219,65 @@ std::tuple npu_moe_init_routing_ return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale); } +std::tuple moe_gating_top_k( + const at::Tensor& x, + int64_t k, + int64_t k_group, + int64_t group_count, + int64_t group_select_mode, + int64_t renorm, + int64_t norm_type, + bool out_flag, + double routed_scaling_factor, + double eps, + const c10::optional& bias_opt + ) +{ + TORCH_CHECK(x.dim() == 2, "The x should be 2D"); + TORCH_CHECK( + x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16, + "float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ", + x.scalar_type()); + + auto x_size = x.sizes(); + auto rows = x_size[0]; + auto expert_num = x_size[1]; + const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); }); + if (bias.defined()) { + TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same"); + 
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D"); + auto bias_size = bias.sizes(); + TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim"); + } + at::Tensor y = at::empty({rows, k}, x.options()); + at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt)); + at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat)); + + EXEC_NPU_CMD(aclnnMoeGatingTopK, + x, + bias, + k, + k_group, + group_count, + group_select_mode, + renorm, + norm_type, + out_flag, + routed_scaling_factor, + eps, + y, + expert_idx, + out + ); + + return std::tuple(y,expert_idx,out); +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) { + // vLLM-Ascend custom ops ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor); @@ -1358,6 +1413,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "num_ranks) -> Tensor"); ops.impl("combine_prefill", torch::kPrivateUse1, &vllm_ascend::combine_prefill); + ops.def( "npu_moe_init_routing_custom(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, " " int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, " @@ -1365,4 +1421,21 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)" ); ops.impl("npu_moe_init_routing_custom", torch::kPrivateUse1, &vllm_ascend::npu_moe_init_routing_custom); + // vLLM-Ascend custom ops + ops.def( + "moe_gating_top_k(Tensor x, " + "int k, " + "int k_group, " + "int group_count, " + "int group_select_mode, " + "int renorm, " + "int norm_type, " + "bool out_flag, " + "float routed_scaling_factor, " + "float eps," + "Tensor? 
bias_opt=None)" + + "-> (Tensor y ,Tensor expert_idx, Tensor out)" + ); + ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index da902ed4..a166ba16 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -366,7 +366,43 @@ std::tuple npu_moe_init_routing_ at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat)); return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale}; } +std::tuple moe_gating_top_k_meta( + const at::Tensor& x, + int64_t k, + int64_t k_group, + int64_t group_count, + int64_t group_select_mode, + int64_t renorm, + int64_t norm_type, + bool out_flag, + double routed_scaling_factor, + double eps, + const c10::optional& bias_opt + + ) +{ + TORCH_CHECK(x.dim() == 2, "The x should be 2D"); + TORCH_CHECK( + x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16, + "float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ", + x.scalar_type()); + auto x_size = x.sizes(); + auto rows = x_size[0]; + auto expert_num = x_size[1]; + const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); }); + if (bias.defined()) { + TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same"); + TORCH_CHECK(bias.dim() == 1, "The bias should be 1D"); + auto bias_size = bias.sizes(); + TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim"); + } + at::Tensor y = at::empty({rows, k}, x.options()); + at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt)); + at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat)); + + return std::tuple(y,expert_idx,out); +} } // namespace meta } // namespace vllm_ascend @@ -374,6 +410,7 @@ namespace { // Register the meta implementations of the custom kernels for symbolic 
tracing, this will also // the custom kernel been captured into aclgraph TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { + // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation @@ -402,5 +439,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta); // moe_init_routing_custom ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta); + // Moe_gating_top_k + ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta); } } diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py index 8a162b91..369be649 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py @@ -305,8 +305,8 @@ def test_select_experts( ) call_moe_gatingtopk = check_npu_moe_gating_top_k( - hidden_states, topk, topk_group, num_expert_group, scoring_func, - custom_routing_function) + hidden_states, topk, renormalize, topk_group, num_expert_group, + scoring_func, custom_routing_function) if not call_moe_gatingtopk and use_grouped_topk: mock_native_grouped_topk.assert_called_once() else: diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_npu_moe_gating_top_k.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_npu_moe_gating_top_k.py new file mode 100644 index 00000000..7928d539 --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_npu_moe_gating_top_k.py @@ -0,0 +1,210 @@ +import random + +import numpy +import pytest +import torch + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +# Fix random seed to ensure test reproducibility +RTOL_TOLERANCE = 1e-5 +ATOL_TOLERANCE = 1e-8 +seed = 45 
+random.seed(seed) +numpy.random.seed(seed) +torch.manual_seed(seed) + + +def softmax_func(x, axis=None): + """Softmax implementation (adapted for numpy calculation)""" + if "float16" in x.dtype.name: + x = x.astype(numpy.float32) + x_max = x.max(axis=axis, keepdims=True) + x_sub = x - x_max + y = numpy.exp(x_sub) + x_sum = y.sum(axis=axis, keepdims=True) + res = y / x_sum + return res, x_max, x_sum + + +def moe_gating_top_k_numpy_ref(x: torch.Tensor, + k: int, + bias: torch.Tensor | None, + k_group: int = 1, + group_count: int = 1, + group_select_mode: int = 0, + renorm: int = 0, + norm_type: int = 0, + y2_flag: bool = False, + routed_scaling_factor: float = 1.0, + eps: float = 1e-20) -> tuple: + """NumPy reference implementation of MOE Gating TopK. + + For result comparison with NPU operator, ensure the consistency + between NPU kernel and baseline implementation. + + Args: + x: Input tensor of shape (num_tokens, num_experts) + k: Number of top-k experts to select + bias: Bias tensor of shape (num_experts,) (optional) + k_group: Number of top-k groups to select + group_count: Number of expert groups + group_select_mode: Group selection mode (0: max, 1: top2 sum) + renorm: Whether to renormalize the output (0/1) + norm_type: Normalization type (0: softmax, 1: sigmoid) + y2_flag: Whether to output original x as y2 + routed_scaling_factor: Scaling factor for routing weights + eps: Small epsilon to avoid division by zero + + Returns: + tuple: (y, indices, y2) + - y: Top-k weights of shape (num_tokens, k) + - indices: Top-k expert indices of shape (num_tokens, k) + - y2: Original x if y2_flag is True, else None + """ + dtype = x.dtype + if dtype != torch.float32: + x = x.to(dtype=torch.float32) + if bias is not None: + bias = bias.to(dtype=torch.float32) + + x = x.numpy() + if bias is not None: + bias = bias.numpy() + + if norm_type == 0: # softmax normalization + x, _, _ = softmax_func(x, -1) + else: # sigmoid normalization + x = 1 / (1 + numpy.exp(-x)) + + 
original_x = x + if bias is not None: + x = x + bias + + if group_count > 1: + x = x.reshape(x.shape[0], group_count, -1) + if group_select_mode == 0: + group_x = numpy.amax(x, axis=-1) + else: + group_x = numpy.partition(x, -2, axis=-1)[..., -2:].sum(axis=-1) + indices = numpy.argsort(-group_x, axis=-1, kind='stable')[:, :k_group] + + mask = numpy.ones((x.shape[0], group_count), dtype=bool) + mask[numpy.arange(x.shape[0])[:, None], indices] = False + x = numpy.where(mask[..., None], float('-inf'), x) + x = x.reshape(x.shape[0], -1) + + _, indices = torch.sort(torch.from_numpy(x), + dim=-1, + stable=True, + descending=True) + indices = numpy.asarray(indices[:, :k]) + + y = numpy.take_along_axis(original_x, indices, axis=1) + if norm_type == 1 or renorm == 1: + y /= (numpy.sum(y, axis=-1, keepdims=True) + eps) + y *= routed_scaling_factor + + y2 = original_x if y2_flag else None + y = torch.tensor(y, dtype=dtype) + return y, indices.astype(numpy.int32), y2 + + +# pytest parameterized decorators (cover all test scenarios) +@pytest.mark.parametrize("group_select_mode", [0, 1]) +@pytest.mark.parametrize("renorm", [1]) +@pytest.mark.parametrize("norm_type", [0, 1]) +@pytest.mark.parametrize("group_count", [1, 8]) +@pytest.mark.parametrize("k_ranges", [4, 8, 12, 16, 6, 32]) +@pytest.mark.parametrize("x_dim0_range", list(range(1, 17))) +@pytest.mark.parametrize("x_dim1_range", [256, 128, 64, 208, 192, 160]) +def test_npu_moe_gating_topk_compare(group_select_mode: int, + renorm: int, + norm_type: int, + group_count: int, + k_ranges: int, + x_dim0_range: int, + x_dim1_range: int, + device: str = "npu"): + """Ascend NPU MOE Gating TopK operator test. + + Compare NPU kernel results with NumPy reference implementation + to verify the correctness of Ascend custom op. 
+ + Args: + group_select_mode: Group selection mode (0: max, 1: top2 sum) + renorm: Whether to renormalize output (fixed to 1 in test) + norm_type: Normalization type (0: softmax, 1: sigmoid) + group_count: Number of expert groups + k_ranges: Number of top-k experts to select + x_dim0_range: First dimension of input tensor (num_tokens) + x_dim1_range: Second dimension of input tensor (num_experts) + device: Target device (fixed to "npu" in test) + """ + # Simplify parameter names for better readability + k = k_ranges + dim0 = x_dim0_range + dim1 = x_dim1_range + + # Skip invalid cases: k cannot exceed num_experts per group + if k > dim1 // group_count: + return + + # Construct test inputs + x = numpy.random.uniform(-2, 2, (dim0, dim1)).astype(numpy.float32) + bias = numpy.random.uniform(-2, 2, (dim1, )).astype(numpy.float32) + + x_tensor = torch.tensor(x, dtype=torch.float32) + bias_tensor = torch.tensor(bias, dtype=torch.float32) + # Fix k_group value to avoid irreproducibility caused by random.randint + k_group = min(1, group_count) + out_flag = False + routed_scaling_factor = 1.0 + eps = 1e-20 + + # Calculate NumPy reference results + y, expert_idx, out = moe_gating_top_k_numpy_ref( + x_tensor, + k=k, + bias=bias_tensor, + k_group=k_group, + group_count=group_count, + group_select_mode=group_select_mode, + renorm=renorm, + norm_type=norm_type, + y2_flag=out_flag, + routed_scaling_factor=routed_scaling_factor, + eps=eps, + ) + + # Calculate NPU operator results + y_npu, expert_idx_npu, out_npu = torch.ops._C_ascend.moe_gating_top_k( + x_tensor.npu(), + k=k, + k_group=k_group, + group_count=group_count, + group_select_mode=group_select_mode, + renorm=renorm, + norm_type=norm_type, + out_flag=out_flag, + routed_scaling_factor=routed_scaling_factor, + eps=eps, + bias_opt=bias_tensor.npu() if bias_tensor is not None else None, + ) + + # Verify consistency between NPU and NumPy results + assert numpy.allclose(y.cpu().numpy(), + y_npu.cpu().numpy(), + 
rtol=RTOL_TOLERANCE, + atol=ATOL_TOLERANCE) + assert numpy.allclose(expert_idx, + expert_idx_npu.cpu().numpy(), + rtol=RTOL_TOLERANCE, + atol=ATOL_TOLERANCE) + + +if __name__ == "__main__": + # Execute pytest tests with verbose output + pytest.main([__file__, "-sv"]) diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index 39200a86..07b611e7 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -17,7 +17,6 @@ from typing import Callable, Optional import torch -import torch_npu from vllm_ascend.utils import get_weight_prefetch_method @@ -64,6 +63,7 @@ def select_experts(hidden_states: torch.Tensor, is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k( hidden_states=hidden_states, top_k=top_k, + renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=scoring_func, @@ -102,10 +102,13 @@ def select_experts(hidden_states: torch.Tensor, def check_npu_moe_gating_top_k( hidden_states: torch.Tensor, top_k: int, + renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, scoring_func: str = "softmax", custom_routing_function: Optional[Callable] = None): + if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch + return False if custom_routing_function is not None: return False if scoring_func != "softmax" and scoring_func != "sigmoid": @@ -209,26 +212,25 @@ def _select_experts_with_fusion_ops( topk_group = topk_group if topk_group is not None else 1 num_expert_group = num_expert_group if num_expert_group is not None else 1 + renorm = int(renormalize) norm_type = 0 if scoring_func == "softmax" else 1 if e_score_correction_bias is not None and \ e_score_correction_bias.dtype != router_logits.dtype: e_score_correction_bias = e_score_correction_bias.to( router_logits.dtype) - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( 
+ topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k( router_logits, k=top_k, - bias=e_score_correction_bias, k_group=topk_group, group_count=num_expert_group, - group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + group_select_mode=1, + renorm=renorm, norm_type=norm_type, # 0: softmax; 1: sigmoid - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output + out_flag=False, routed_scaling_factor=routed_scaling_factor, - eps=float(1e-20)) - if scoring_func == "softmax": - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + eps=float(1e-20), + bias_opt=e_score_correction_bias, + ) return topk_weights, topk_ids