From 71f729a6617383aa510240742f41724531d5fc37 Mon Sep 17 00:00:00 2001 From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:05:47 +0800 Subject: [PATCH] Revert "moe_gating_top_k" (#5512) Reverts vllm-project/vllm-ascend#5271 It breaks e2e test - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4 --- csrc/build_aclnn.sh | 3 +- csrc/moe_gating_top_k/op_host/CMakeLists.txt | 43 -- csrc/moe_gating_top_k/op_host/error_log.h | 56 -- csrc/moe_gating_top_k/op_host/math_util.h | 61 -- .../op_host/moe_gating_top_k_def.cpp | 71 -- .../op_host/moe_gating_top_k_infershape.cpp | 147 ---- .../op_host/moe_gating_top_k_proto.cpp | 15 - .../op_host/moe_gating_top_k_proto.h | 66 -- .../op_host/moe_gating_top_k_tiling.cpp | 580 --------------- .../op_host/moe_gating_top_k_tiling.h | 86 --- .../moe_gating_top_k_tiling_arch35.cpp | 521 -------------- .../op_host/moe_gating_top_k_tiling_base.cpp | 38 - csrc/moe_gating_top_k/op_kernel/common.h | 89 --- csrc/moe_gating_top_k/op_kernel/error_log.h | 56 -- .../op_kernel/moe_gating_top_k.cpp | 63 -- .../op_kernel/moe_gating_top_k_apt.cpp | 46 -- .../op_kernel/moe_gating_top_k_e_k_fullload.h | 404 ----------- .../op_kernel/moe_gating_top_k_generalized.h | 669 ------------------ .../moe_gating_top_k_without_group.h | 338 --------- .../tiling_base/data_copy_transpose_tiling.h | 51 -- .../data_copy_transpose_tiling_def.h | 43 -- csrc/moe_gating_top_k/tiling_base/error_log.h | 56 -- .../tiling_base/tiling_base.h | 256 ------- .../moe_gating_top_k/tiling_base/tiling_key.h | 63 -- .../tiling_base/tiling_templates_registry.h | 351 --------- .../tiling_base/tiling_type.h | 139 ---- .../tiling_base/tiling_util.h | 30 - csrc/torch_binding.cpp | 73 +- csrc/torch_binding_meta.cpp | 39 - .../test_dispatch_gmm_combine_decode.py | 9 +- .../nightly/ops/test_npu_moe_gating_top_k.py | 322 --------- .../netloader/test_netloader_elastic.py | 2 +- tests/ut/worker/test_worker_v1.py | 2 +- vllm_ascend/ops/fused_moe/experts_selector.py | 25 +- 34 files changed, 22 insertions(+), 4791 deletions(-) delete mode 100644 csrc/moe_gating_top_k/op_host/CMakeLists.txt delete mode 100644 csrc/moe_gating_top_k/op_host/error_log.h delete mode 100644 csrc/moe_gating_top_k/op_host/math_util.h delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp delete mode 100644 csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp delete mode 100644 csrc/moe_gating_top_k/op_kernel/common.h delete mode 100644 csrc/moe_gating_top_k/op_kernel/error_log.h delete mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp delete mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp delete mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h delete mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h delete mode 100644 csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/error_log.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_base.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_key.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_type.h delete mode 100644 csrc/moe_gating_top_k/tiling_base/tiling_util.h delete mode 100644 tests/e2e/nightly/ops/test_npu_moe_gating_top_k.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 9aec1da5..f24a9f02 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -70,7 +70,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "dispatch_layout" "notify_dispatch" "moe_init_routing_custom" - "moe_gating_top_k" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/moe_gating_top_k/op_host/CMakeLists.txt b/csrc/moe_gating_top_k/op_host/CMakeLists.txt deleted file mode 100644 index f29c92ac..00000000 --- a/csrc/moe_gating_top_k/op_host/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -# ---------------------------------------------------------------------------- -# This program is free software, you can redistribute it and/or modify. -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ---------------------------------------------------------------------------- - -add_ops_compile_options( - OP_NAME MoeGatingTopK - OPTIONS --cce-auto-sync=on - -Wno-deprecated-declarations - -Werror -) - -# Host 侧算子实现(aclnn) -if (BUILD_OPEN_PROJECT) - target_sources(op_host_aclnn PRIVATE - moe_gating_top_k_def.cpp - - ) - - - - - - # Tiling 模块 - target_sources(optiling PRIVATE - moe_gating_top_k_tiling.cpp - moe_gating_top_k_tiling_base.cpp - moe_gating_top_k_tiling_arch35.cpp - ) - target_sources(opsproto PRIVATE - moe_gating_top_k_proto.cpp - moe_gating_top_k_infershape.cpp - - ) - target_include_directories(optiling PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} - ) -endif() diff --git a/csrc/moe_gating_top_k/op_host/error_log.h b/csrc/moe_gating_top_k/op_host/error_log.h deleted file mode 100644 index 8b64eced..00000000 --- a/csrc/moe_gating_top_k/op_host/error_log.h +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ -#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ - -#include -#include "toolchain/slog.h" - -#define OP_LOGI(opname, ...) -#define OP_LOGW(opname, ...) \ - do { \ - printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ - do { \ - printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE(opname, ...) \ - do { \ - printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGD(opname, ...) - -namespace optiling { - -#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ - do { \ - OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ - } while (0) - -// 修改 OP_TILING_CHECK 宏,确保正确处理表达式 -#define OP_CHECK_IF(cond, log_func, expr) \ - do { \ - if (cond) { \ - log_func; \ - expr; \ - } \ - } while (0) - - - -#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ - do { \ - if ((ptr) == nullptr) { \ - OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -} // namespace optiling - -#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/op_host/math_util.h b/csrc/moe_gating_top_k/op_host/math_util.h deleted file mode 100644 index edc1c8ea..00000000 --- a/csrc/moe_gating_top_k/op_host/math_util.h +++ /dev/null @@ -1,61 +0,0 @@ -/** -* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This program is free software, you can redistribute it and/or modify it under the terms and conditions of -* CANN Open Software License Agreement Version 2.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -*/ - -/*! - * \file math_util.h - * \brief - */ - -#ifndef TILING_MATMUL_MATH_UTIL_H -#define TILING_MATMUL_MATH_UTIL_H - -#include -#include -#include -#include -namespace matmul_tiling { -class MathUtil { -public: - static bool IsEqual(float leftValue, float rightValue); - template - static auto CeilDivision(T num1, T num2) -> T - { - if (num2 == 0) { - return 0; - } - return static_cast((static_cast(num1) + static_cast(num2) - 1) / - static_cast(num2)); - } - template - static auto Align(T num1, T num2) -> T - { - return CeilDivision(num1, num2) * num2; - } - static int32_t AlignDown(int32_t num1, int32_t num2); - static bool CheckMulOverflow(int32_t a, int32_t b, int32_t &c); - static int32_t MapShape(int32_t shape, bool roundUpFlag = true); - static void AddFactor(std::vector &dimsFactors, int32_t dim); - static void GetFactorCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart, - const int32_t factorEnd); - static void GetFactorLayerCnt(const int32_t shape, int32_t &factorCnt, const int32_t factorStart, - const int32_t factorEnd); - static bool CheckFactorNumSatisfy(const int32_t dim); - static int32_t FindBestSingleCore(const int32_t oriShape, const int32_t mappedShape, const int32_t coreNum, - bool isKDim); - static void GetFactors(std::vector &factorList, int32_t srcNum, int32_t minFactor, int32_t maxFactor); - static void GetFactors(std::vector &factorList, int32_t srcNum, int32_t maxFactor); - static void GetBlockFactors(std::vector &factorList, const int32_t oriShape, const int32_t mpShape, - const int32_t coreNum, const int32_t maxNum); - static int32_t GetNonFactorMap(std::vector &factorList, int32_t srcNum, int32_t maxFactor); - static std::vector> GetFactorPairs(int32_t num); - static std::pair DivideIntoMainAndTail(int32_t num, int32_t divisor); -}; -} // namespace matmul_tiling -#endif // _MATH_UTIL_H_ diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp deleted file mode 100644 index 52ecf37d..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_def.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_def.cpp - * \brief - */ -#include "register/op_def_registry.h" - -namespace ops { -class MoeGatingTopK : public OpDef { -public: - explicit MoeGatingTopK(const char *name) : OpDef(name) - { - this->Input("x") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Input("bias") - .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .AutoContiguous(); - this->Output("y") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("expert_idx") - .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("out") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Attr("k").Int(); - this->Attr("k_group").AttrType(OPTIONAL).Int(1); - this->Attr("group_count").AttrType(OPTIONAL).Int(1); - this->Attr("group_select_mode").AttrType(OPTIONAL).Int(0); - this->Attr("renorm").AttrType(OPTIONAL).Int(0); - this->Attr("norm_type").AttrType(OPTIONAL).Int(0); - this->Attr("out_flag").AttrType(OPTIONAL).Bool(false); - this->Attr("routed_scaling_factor").AttrType(OPTIONAL).Float(1.0); - this->Attr("eps").AttrType(OPTIONAL).Float(1e-20f); - this->AICore().AddConfig("ascend910b"); - this->AICore().AddConfig("ascend910_93"); - - OpAICoreConfig regbaseCfg; - regbaseCfg.DynamicCompileStaticFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .ExtendCfgInfo("opFile.value", "moe_gating_top_k_apt"); - this->AICore().AddConfig("ascend910_95", regbaseCfg); - } -}; - -OP_ADD(MoeGatingTopK); -} // namespace ops \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp deleted file mode 100644 index 37b72d1c..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_infershape.cpp +++ /dev/null @@ -1,147 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/* ! - * \file moe_gating_top_k_infershape.cpp - * \brief - */ - -#include "exe_graph/runtime/infer_shape_context.h" -#include "register/op_impl_registry.h" -#include "error_log.h" - -#include - -#include -#define TO_STRING(x) std::string(#x) - -using namespace ge; -namespace ops { -static constexpr size_t DIM_ONE = 1; -static constexpr size_t DIM_TWO = 2; -static constexpr int64_t NEG_ONE = -1; -static constexpr int64_t X_INDEX = 0; -static constexpr int64_t BIAS_INDEX = 1; -static constexpr int64_t Y_INDEX = 0; -static constexpr int64_t EXPERT_IDX_INDEX = 1; -static constexpr int64_t OUT_INDEX = 2; - -static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape) -{ - int64_t XRows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0); - int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1); - if (XRows < NEG_ONE || expertNum < NEG_ONE) { - OP_LOGE(context, "Invalid x shape, shape is %s.", TO_STRING(*xShape).c_str()); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus CheckInputDimsAndAttr(gert::InferShapeContext *context, const gert::Shape *xShape, - const int64_t k) -{ - if (xShape->GetDimNum() == 1U) { - if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) { - OP_LOGE(context, "The dynamic dim of x should be -2, current shape is %s.", - TO_STRING(*xShape).c_str()); - return ge::GRAPH_FAILED; - } - } else if (xShape->GetDimNum() != DIM_TWO) { - OP_LOGE(context, "The dim of x should be 2 or dynamic, current shape is %s.", - TO_STRING(*xShape).c_str()); - return ge::GRAPH_FAILED; - } - - if (k < 0) { - OP_LOGE(context, "k must be a non-negative number."); - return ge::GRAPH_FAILED; - } - - return ge::GRAPH_SUCCESS; -} - -static void ShowInputShapeInfo(gert::InferShapeContext *context, const gert::Shape *xShape, const int64_t k) -{ - OP_LOGD(context, "x shape is: %s.", TO_STRING(*xShape).c_str()); - OP_LOGD(context, "k is: %ld.", k); -} - -static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *yShape, - const gert::Shape *expertIdxShape, const gert::Shape *outShape) -{ - OP_LOGD(context, "y shape is: %s after infershape.", TO_STRING(*yShape).c_str()); - OP_LOGD(context, "expert_idx shape is: %s after infershape.", TO_STRING(*expertIdxShape).c_str()); - OP_LOGD(context, "out shape is: %s after infershape.", TO_STRING(*outShape).c_str()); -} - -static ge::graphStatus InferShape4MoeGatingTopK(gert::InferShapeContext *context) -{ - OP_LOGD(context, "Begin to do MoeGatingTopKInfershape."); - - // 获取输入shape - const gert::Shape *xShape = context->GetInputShape(0); - OP_CHECK_NULL_WITH_CONTEXT(context, xShape); - gert::Shape *yShape = context->GetOutputShape(0); - OP_CHECK_NULL_WITH_CONTEXT(context, yShape); - gert::Shape *expertIdxShape = context->GetOutputShape(1); - OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape); - gert::Shape *outShape = context->GetOutputShape(2); - OP_CHECK_NULL_WITH_CONTEXT(context, outShape); - - // 获取attr - auto attrs = context->GetAttrs(); - OP_CHECK_NULL_WITH_CONTEXT(context, attrs); - const int64_t *kPtr = attrs->GetAttrPointer(0); - OP_CHECK_NULL_WITH_CONTEXT(context, kPtr); - const int64_t k = *kPtr; - ShowInputShapeInfo(context, xShape, k); - - // 参数校验 - if (CheckInputDimsAndAttr(context, xShape, k) != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - - if (CheckInputShape(context, xShape) != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - - int64_t rows = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(0); - int64_t expertNum = xShape->GetDimNum() == 1U ? NEG_ONE : xShape->GetDim(1); - - yShape->SetDimNum(DIM_TWO); - yShape->SetDim(0U, rows); - yShape->SetDim(1U, k); - - expertIdxShape->SetDimNum(DIM_TWO); - expertIdxShape->SetDim(0U, rows); - expertIdxShape->SetDim(1U, k); - - outShape->SetDimNum(DIM_TWO); - outShape->SetDim(0U, rows); - outShape->SetDim(1U, expertNum); - - ShowOutputShapeInfo(context, yShape, expertIdxShape, outShape); - OP_LOGD(context, "End to do MoeGatingTopKInfershape."); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus InferDataType4MoeGatingTopK(gert::InferDataTypeContext *context) -{ - OP_LOGD(context, "Begin to do MoeGatingTopKInferDataType."); - auto xDtype = context->GetInputDataType(0); - context->SetOutputDataType(Y_INDEX, xDtype); - context->SetOutputDataType(EXPERT_IDX_INDEX, ge::DT_INT32); - context->SetOutputDataType(OUT_INDEX, ge::DT_FLOAT); - OP_LOGD(context, "End to do MoeGatingTopKInferDataType."); - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_INFERSHAPE(MoeGatingTopK).InferShape(InferShape4MoeGatingTopK).InferDataType(InferDataType4MoeGatingTopK); -} // namespace ops diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp deleted file mode 100644 index f10adf71..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.cpp +++ /dev/null @@ -1,15 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_proto.h - * \brief - */ -#include "moe_gating_top_k_proto.h" \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h deleted file mode 100644 index 068acdfc..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_proto.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_proto.h - * \brief - */ -#ifndef OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ -#define OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ - -#include "graph/operator_reg.h" - -namespace ge { - -/** - * @brief Compute renorm(sigmoid) and topk for moe input. - * - * @par Inputs: - * @li x: A 2D tensor which moe gating topk is applied, The shape is: (B*S, E), format supports ND, and data type must be float16, float or bfloat16. E(Expert num) can not be greater than 2048. E(Expert num) should be divisible by group_count. - * @li bias: A 1D tensor which is "bias" in moe gating topk. The shape is: (E), format supports ND, and data type must be the same as that of x. - * - * @par Outputs: - * @li y: A 2D tensor which is the topk value result of moe gating topk, format supports ND, and data type must be the same as that of x. - The size of the non-1 axis must be the same as that of the corresponding axis of x. - The size of the -1 axis must be the same as that of k. - * @li expert_idx: A 2D tensor which is the topk index result of moe gating topk, format supports ND, and data type must be int. The shape must be the same as that of y. - * @li out: A 2D tensor which is the renorm result of moe gating topk, format supports ND, and data type must be float. The shape must be the same as that of x. - * - * @par Attributes: - * @li k: A required attribute of type int. The value must greater than 0 and less than or equal to expert_num / group_count * k_group, idicating the topk value. - * @li k_group: An optional attribute of type int. It can not be less than 1, and can not be greater than group_count, indicating the topk group value. The default value is 1. - * @li group_count: An optional attribute of type int. It can not be less than 1, indicating the group count. The group_count * align_32(expert_num / group_count) can not be greater than 2048. The default value is 1. - * @li group_select_mode: An optional attribute of type int. 0 indicating that sort group by max values, 1 indicating that sort group by sum of top-2 values. The default value is 0. - * @li renorm: An optional attribute of type int. It can only be 0 now, indicating that norm firstly and then topk. The default value is 0. - * @li norm_type: An optional attribute of type int. 0 indicating that the softmax function is used, 1 indicating that the sigmoid function is used. The default value is 0. - * @li out_flag: An optional attribute of type bool. true indicating that has renorm output, false indicating that does not have renorm output. The default value is false. - * @li routed_scaling_factor: An optional attribute of type float, indicating the routed_scaling_factor coefficient in use. The default value is 1.0. - * @li eps: An optional attribute of type float, indicating the eps coefficient in use. The default value is 1e-20. - */ -REG_OP(MoeGatingTopK) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16})) - .OUTPUT(expert_idx, TensorType({DT_INT32})) - .OUTPUT(out, TensorType({DT_FLOAT})) - .REQUIRED_ATTR(k, Int) - .ATTR(k_group, Int, 1) - .ATTR(group_count, Int, 1) - .ATTR(group_select_mode, Int, 0) - .ATTR(renorm, Int, 0) - .ATTR(norm_type, Int, 0) - .ATTR(out_flag, Bool, false) - .ATTR(routed_scaling_factor, Float, 1.0) - .ATTR(eps, Float, 1e-20f) - .OP_END_FACTORY_REG(MoeGatingTopK) - -} // namespace ge - -#endif // OPS_OP_PROTO_INC_MOEGATINGTOPK_H_ \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp deleted file mode 100644 index affc22af..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.cpp +++ /dev/null @@ -1,580 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/* ! - * \file moe_gating_top_k_tiling.cpp - * \brief - */ -#include -#include "register/op_def_registry.h" -#include "exe_graph/runtime/infer_shape_context.h" -#include "register/op_impl_registry.h" -#include "../tiling_base/tiling_base.h" -#include "../tiling_base/tiling_templates_registry.h" -#include "platform/platform_info.h" - - - -#include "error_log.h" -#include "moe_gating_top_k_tiling.h" - -// 放在文件顶部,或单独头文件中 -#ifndef CEIL_ALIGN -#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align)) -#endif - -#ifndef CEIL_DIV -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) -#endif -namespace optiling { -const static int64_t GROUP_SELECT_MODE_MAX = 0; -const static int64_t GROUP_SELECT_MODE_SUM = 1; -const static int64_t RENORM_NO = 0; -const static int64_t RENORM_L1 = 1; -const static int64_t NORM_TYPE_SOFTMAX = 0; -const static int64_t NORM_TYPE_SIGMOID = 1; -const static int64_t OUT_FLAG_FALSE = 0; -const static int64_t OUT_FLAG_TRUE = 1; -const static size_t X_INPUT_DIMS = 2; -const static size_t BIAS_INPUT_DIMS = 1; -const static size_t Y_OUTPUT_DIMS = 2; -const static size_t EXPERT_IDX_OUTPUY_DIMS = 2; -const static size_t OUT_OUTPUT_DIMS = 2; -const static int64_t MAX_EXPERT_COUNT = 2048; - -const static int64_t X_INPUT_INDEX = 0; -const static int64_t BIAS_INPUT_INDEX = 1; -const static int64_t Y_OUTPUT_INDEX = 0; -const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1; -const static int64_t OUT_OUTPUT_INDEX = 2; -const static int64_t K_ATTR_INDEX = 0; -const static int64_t K_GROUP_ATTR_INDEX = 1; -const static int64_t GROUP_COUNT_ATTR_INDEX = 2; -const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3; -const static int64_t RENORM_ATTR_INDEX = 4; -const static int64_t NORM_TYPE_ATTR_INDEX = 5; -const static int64_t OUT_FLAG_ATTR_INDEX = 6; -const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7; -const static int64_t EPS_ATTR_INDEX = 8; -const static int64_t DEFAULT_WORKSPACE_SIZE = 16777216; // 预留16M空间 -const static uint32_t DATATYPESIZE_FLOAT = 4; -const static bool IS_LARGEST = true; -const static bool IS_INITINDEX = false; -const static bool IS_REUSESOURCE = false; -const static uint64_t WITH_GROUP_CONDITION = 1; -const static uint64_t WITHOUT_GROUP_CONDITION = 2; -const static uint64_t MAX_IN_GROUP_CONDITION = 3; -constexpr int32_t ROW_COUNT_PER_TASK = 1; - -const static uint64_t TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF = 0; -const static uint64_t TILING_KEY_WITHOUT_GROUP = 1; -const static uint64_t TILING_KEY_GENERALIZED = 2; - -inline static int64_t CeilLog4(int64_t x) -{ - return static_cast(std::ceil(std::log(x) / std::log(4))); // 4 for four -} - -class MoeGatingTopKTilingBase : public Ops::Transformer::OpTiling::TilingBaseClass { -public: - explicit MoeGatingTopKTilingBase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context) - { - Reset(); - } - ~MoeGatingTopKTilingBase() override = default; - - void Reset(gert::TilingContext *context) override - { - TilingBaseClass::Reset(context); - Reset(); - } - -protected: - bool IsCapable() override - { - return true; - } - // 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小 - ge::graphStatus GetPlatformInfo() override; - // 2、获取INPUT/OUTPUT/ATTR信息 - ge::graphStatus GetShapeAttrsInfo() override; - // 3、计算数据切分TilingData - ge::graphStatus DoOpTiling() override; - // 4、计算高阶API的TilingData - ge::graphStatus DoLibApiTiling() override; - // 5、计算TilingKey - uint64_t GetTilingKey() const override; - // 6、计算Workspace 大小 - ge::graphStatus GetWorkspaceSize() override; - // 7、保存Tiling数据 - ge::graphStatus PostTiling() override; - void Reset(); - -private: - ge::graphStatus CheckInputShape(); - ge::graphStatus CheckAttr(); - ge::graphStatus CheckOutShape(); - void SplitRows(); - void CalTmpBufUbSize(); - - const gert::Shape *xShape_ = nullptr; - const gert::Shape *biasShape_ = nullptr; - const gert::Shape *yShape_ = nullptr; - const gert::Shape *expertIdxShape_ = nullptr; - const gert::Shape *outShape_ = nullptr; - - int64_t rows_ = 0; - int64_t expertCount_ = 0; - int64_t addBias_ = 0; - - int64_t k_ = 0; - int64_t kGroup_ = 0; - int64_t groupCount_ = 0; - int64_t perGroupExpertCount_ = 0; - int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX; - int64_t renorm_ = RENORM_NO; - int64_t normType_ = NORM_TYPE_SOFTMAX; - int64_t outFlag_ = OUT_FLAG_FALSE; - float routedScalingFactor_ = 1.0; - float eps_ = 1e-20f; - - int64_t inputDtypeSize_; - const char *opName_ = ""; - MoeGatingTopKTilingData moeGatingTopKTilingData_; -}; - -ge::graphStatus MoeGatingTopKTilingBase::CheckInputShape() -{ - size_t xDimNum = xShape_->GetDimNum(); - - OP_CHECK_IF(xDimNum != X_INPUT_DIMS, - - OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS), - return ge::GRAPH_FAILED); - - // 通过输入获取rows 和 expertCount - rows_ = xShape_->GetDim(0); - expertCount_ = xShape_->GetDim(1); - - moeGatingTopKTilingData_.set_rowCount(rows_); - moeGatingTopKTilingData_.set_expertCount(expertCount_); - if (biasShape_ != nullptr) { - addBias_ = 1; - size_t biasDimNum = biasShape_->GetDimNum(); - OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS, - OP_LOGE(context_, "The dim number of bias is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS), - return ge::GRAPH_FAILED); - OP_CHECK_IF( - biasShape_->GetDim(0) != expertCount_, - OP_LOGE(context_, "The first dim of bias is: %ld, but should be %ld.", biasShape_->GetDim(0), expertCount_), - return ge::GRAPH_FAILED); - - } - moeGatingTopKTilingData_.set_addBias(addBias_); - - OP_CHECK_IF(k_ > expertCount_, - OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_), - return ge::GRAPH_FAILED); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::CheckAttr() -{ - OP_CHECK_IF( - expertCount_ > MAX_EXPERT_COUNT, - OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED); - - OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(kGroup_ > groupCount_, - OP_LOGE(context_, "k_group is: %ld, but should not greater than %ld.", kGroup_, groupCount_), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID, - OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_, - NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX, - OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_, - GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(renorm_ != RENORM_NO && renorm_ != RENORM_L1, - OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(expertCount_ % groupCount_ != 0, - OP_LOGE(context_, "Expert count : %ld is not divisible by k_group: %ld", expertCount_, groupCount_), - return ge::GRAPH_FAILED); - - perGroupExpertCount_ = expertCount_ / groupCount_; - - OP_LOGI(context_, "perGroupExpertCount_: %ld", perGroupExpertCount_); - - OP_CHECK_IF(perGroupExpertCount_ < 1, - OP_LOGE(context_, "group expert count is: %ld, but should be greater than 1.", perGroupExpertCount_), - return ge::GRAPH_FAILED); - OP_CHECK_IF( - groupSelectMode_ == GROUP_SELECT_MODE_SUM && perGroupExpertCount_ < 2, - OP_LOGE(context_, - "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.", - perGroupExpertCount_, groupSelectMode_), - return ge::GRAPH_FAILED); - OP_CHECK_IF(k_ > kGroup_ * perGroupExpertCount_, - OP_LOGE(context_, "k is: %ld, but should be smaller than %ld.", k_, kGroup_ * perGroupExpertCount_), - return ge::GRAPH_FAILED); - int64_t groupExpertCountAlign = CEIL_ALIGN(perGroupExpertCount_, 32L); - OP_LOGI(context_, "333groupExpertCountAlign: %ld", groupExpertCountAlign); - if (groupCount_ != 1 && groupCount_ != expertCount_ && kGroup_ != groupCount_) { - // 分组场景下才需要校验对齐后的数量 - OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT, - OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.", - groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT), - return ge::GRAPH_FAILED); - } - - moeGatingTopKTilingData_.set_perGroupExpertCount(perGroupExpertCount_); - moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::GetShapeAttrsInfo() -{ - opName_ = context_->GetNodeName(); - // 获取输入shape信息 - OP_LOGI(context_, "111GetShapeAttrsInfo: opName = %s", opName_); - auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr); - xShape_ = &xShapePtr->GetStorageShape(); - OP_LOGI(context_, "112xShape: %s", xShape_->ToString().c_str()); - - auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX); - biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape(); - if (biasShape_ != nullptr) { - OP_LOGI(context_, "113biasShape: %s", biasShape_->ToString().c_str()); - } - // 获取输出shape - auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr); - yShape_ = &yShapePtr->GetStorageShape(); - OP_LOGI(context_, "115yShape: %s", yShape_->ToString().c_str()); - auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr); - expertIdxShape_ = &expertIdxPtr->GetStorageShape(); - OP_LOGI(context_, "116expertIdxShape: %s", expertIdxShape_->ToString().c_str()); - auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, outPtr); - outShape_ = &outPtr->GetStorageShape(); - if (outShape_ != nullptr) { - OP_LOGI(context_, "117outShape: %s", outShape_->ToString().c_str()); - } - - auto x = context_->GetInputDesc(X_INPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, x); - auto xDtype = x->GetDataType(); - OP_CHECK_IF( - (xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16), - OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.", - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - - if (biasShapePtr != nullptr) { - auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType(); - OP_LOGI(context_, "118bias dtype: %s", ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str()); - OP_CHECK_IF((biasDtype != xDtype), - OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.", - ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(), - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - } - - auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc); - auto yDtype = yDesc->GetDataType(); - OP_LOGI(context_, "119y dtype: %s", ge::TypeUtils::DataTypeToSerialString(yDtype).c_str()); - OP_CHECK_IF((yDtype != xDtype), - OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.", - ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(), - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - - auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc); - auto expertIdDtype = expertIdDesc->GetDataType(); - OP_LOGI(context_, "120expertId dtype: %s", ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()); - OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32), - OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.", - ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()), - return ge::GRAPH_FAILED); - - auto normOutDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, normOutDesc); - auto normOutDtype = normOutDesc->GetDataType(); - OP_CHECK_IF((normOutDtype != ge::DataType::DT_FLOAT), - OP_LOGE(context_, "norm out dtype %s error, only supports float. please check.", - ge::TypeUtils::DataTypeToSerialString(normOutDtype).c_str()), - return ge::GRAPH_FAILED); - - // 获取属性 - auto attrs = context_->GetAttrs(); - OP_CHECK_NULL_WITH_CONTEXT(context_, attrs); - - const int64_t *kPtr = attrs->GetAttrPointer(K_ATTR_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr); - k_ = *kPtr; - OP_LOGI(context_, "Attr k is: %ld", k_); - moeGatingTopKTilingData_.set_k(k_); - - OP_LOGI(context_, "Attr k is: %ld ", k_); - - const int64_t *kGroupPtr = attrs->GetAttrPointer(K_GROUP_ATTR_INDEX); - if (kGroupPtr != nullptr) { - kGroup_ = *kGroupPtr; - OP_LOGI(context_, "Attr k_group is: %ld", kGroup_); - moeGatingTopKTilingData_.set_kGroup(kGroup_); - } - OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_); - - const int64_t *groupCountPtr = attrs->GetAttrPointer(GROUP_COUNT_ATTR_INDEX); - if (groupCountPtr != nullptr) { - groupCount_ = *groupCountPtr; - OP_LOGI(context_, "Attr group_count is: %ld", groupCount_); - moeGatingTopKTilingData_.set_groupCount(groupCount_); - } - OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_); - - const int64_t *groupSelectModePtr = attrs->GetAttrPointer(GROUP_SELECT_MODE_ATTR_INDEX); - if (groupSelectModePtr != nullptr) { - groupSelectMode_ = *groupSelectModePtr; - OP_LOGI(context_, "Attr group_select_mode is: %ld", groupSelectMode_); - moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_); - } - OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_); - - const int64_t *renormPtr = attrs->GetAttrPointer(RENORM_ATTR_INDEX); - if (renormPtr != nullptr) { - renorm_ = *renormPtr; - OP_LOGI(context_, "Attr renorm is: %ld", renorm_); - moeGatingTopKTilingData_.set_renorm(renorm_); - } - OP_LOGI(context_, "Attr renorm is: %ld ", renorm_); - - const int64_t *normTypePtr = attrs->GetAttrPointer(NORM_TYPE_ATTR_INDEX); - if (normTypePtr != nullptr) { - normType_ = *normTypePtr; - OP_LOGI(context_, "Attr norm_type is: %ld", normType_); - moeGatingTopKTilingData_.set_normType(normType_); - } - OP_LOGI(context_, "Attr norm_type is: %ld ", normType_); - - const bool *outFlagPtr = attrs->GetAttrPointer(OUT_FLAG_ATTR_INDEX); - if (outFlagPtr != nullptr) { - outFlag_ = (*outFlagPtr) ? 1 : 0; - OP_LOGI(context_, "Attr out_flag is: %ld", outFlag_); - moeGatingTopKTilingData_.set_outFlag(outFlag_); - } - OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_); - - const float *routedScalingFactorPtr = attrs->GetAttrPointer(ROUTED_SCALING_FACTOR_ATTR_INDEX); - if (routedScalingFactorPtr != nullptr) { - routedScalingFactor_ = *routedScalingFactorPtr; - OP_LOGI(context_, "Attr routed_scaling_factor is: %f", routedScalingFactor_); - moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_); - } - OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_); - - const float *epsPtr = attrs->GetAttrPointer(EPS_ATTR_INDEX); - if (epsPtr != nullptr) { - eps_ = *epsPtr; - OP_LOGI(context_, "Attr eps is: %f", eps_); - moeGatingTopKTilingData_.set_eps(eps_); - } - OP_LOGI(context_, "Attr eps is: %f ", eps_); - - inputDtypeSize_ = static_cast(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType())); - OP_LOGI(context_, "inputDtypeSize_: %ld", inputDtypeSize_); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::GetPlatformInfo() - -{ - auto platformInfo = context_->GetPlatformInfo(); - OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED); - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); - aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv(); - uint64_t ubSizePlatForm; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm); - aicoreParams_.ubSize = ubSizePlatForm; - OP_LOGI(context_, "GetPlatformInfo: blockDim = %ld, ubSize = %lu", aicoreParams_.blockDim, aicoreParams_.ubSize); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::CheckOutShape() -{ - OP_LOGI(context_, "555CheckOutShape: yShape_: %s, xShape_: %s", yShape_->ToString().c_str(), xShape_->ToString().c_str()); - OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(), - xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.", - expertIdxShape_->GetDimNum(), xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - if (outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.", - outShape_->GetDimNum(), xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - } - - OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0), - xShape_->GetDim(0)), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.", - expertIdxShape_->GetDim(0), xShape_->GetDim(0)), - return ge::GRAPH_FAILED); - if (outFlag_ && outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.", - outShape_->GetDim(0), outShape_->GetDim(0)), - return ge::GRAPH_FAILED); - } - - OP_CHECK_IF((yShape_->GetDim(1) != k_), - OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_), - OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_), - return ge::GRAPH_FAILED); - if (outFlag_ && outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)), - OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1), - xShape_->GetDim(1)), - return ge::GRAPH_FAILED); - } - return ge::GRAPH_SUCCESS; -} - -void MoeGatingTopKTilingBase::SplitRows() -{ - int64_t perCoreRows = CEIL_DIV(rows_, static_cast(aicoreParams_.blockDim)); - int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows); - // perCoreRows cannot be 0 - int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows; - moeGatingTopKTilingData_.set_needCoreNum(needCoreNum); - moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows); - moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows); - int64_t vmsCount = CeilLog4(CEIL_DIV(kGroup_, 4L)); - OP_LOGI(context_, "vms count is: %ld", vmsCount); - moeGatingTopKTilingData_.set_vmsCount(vmsCount); // 需要归并的轮数 -} - -void MoeGatingTopKTilingBase::CalTmpBufUbSize() - -{ - - std::vector shape_vec = {expertCount_}; - ge::Shape shape(shape_vec); - uint32_t maxValue = 0; - uint32_t minValue = 0; - AscendC::GetSigmoidMaxMinTmpSize(shape, sizeof(float), false, maxValue, minValue); - - int64_t indexTmpBuf = (expertCount_ + 31) / 32 * 32 * static_cast(sizeof(float)); - moeGatingTopKTilingData_.set_calTmpBufUbSize(std::max(indexTmpBuf, static_cast(minValue))); -} - -ge::graphStatus MoeGatingTopKTilingBase::DoOpTiling() -{ - - OP_LOGI(context_, "DoOpTiling: start"); - auto ret = CheckInputShape(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - ret = CheckOutShape(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - ret = CheckAttr(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - CalTmpBufUbSize(); - SplitRows(); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::DoLibApiTiling() -{ - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::GetWorkspaceSize() -{ - // 计算workspace大小 - workspaceSize_ = DEFAULT_WORKSPACE_SIZE; - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingBase::PostTiling() -{ - context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum()); - size_t *currentWorkspace = context_->GetWorkspaceSizes(1); - currentWorkspace[0] = workspaceSize_; - moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), - context_->GetRawTilingData()->GetCapacity()); - context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize()); - return ge::GRAPH_SUCCESS; -} - -uint64_t MoeGatingTopKTilingBase::GetTilingKey() const -{ - // DeepSeekV3排序对齐高性能场景 - if (expertCount_ == 256 && groupCount_ == 8 && kGroup_ == 4 && k_ <= 32 && addBias_ && - groupSelectMode_ == GROUP_SELECT_MODE_SUM && renorm_ == RENORM_NO && normType_ == NORM_TYPE_SIGMOID && - !outFlag_) { - // DeepSeekV3排序对齐高性能场景 - return TILING_KEY_EXPERTNUM_GROUPNUM_ALIGN_HIGH_PERF; - } else if (groupCount_ == 1 || groupCount_ == expertCount_ || kGroup_ == groupCount_) { - /** - * 不分组场景: - * 1. 分组数为 1 - * 2. 分组数等于专家数(每个组只有一个专家) - * 3. 选择所有组 - */ - return TILING_KEY_WITHOUT_GROUP; - } else { - return TILING_KEY_GENERALIZED; - } -} - -void MoeGatingTopKTilingBase::Reset() -{ - opName_ = nullptr; - return; -} - -REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingBase, 2000); -} // namespace optiling diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h deleted file mode 100644 index 6ac4aa46..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_tiling.h - * \brief - */ - -#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H -#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H - -#include -#include -#include -#include - - -#include "../tiling_base/tiling_base.h" -#include "../tiling_base/tiling_templates_registry.h" -#include "register/op_def_registry.h" -#include "register/op_impl_registry.h" -#include "register/tilingdata_base.h" -#include "tiling/tiling_api.h" -#include "error_log.h" - -#include "register/op_impl_registry.h" -#include "platform/platform_infos_def.h" -#include "math_util.h" -//#include "util/extern_math_util.h" - -namespace optiling { -BEGIN_TILING_DATA_DEF(MoeGatingTopKTilingData) -TILING_DATA_FIELD_DEF(int64_t, needCoreNum); -TILING_DATA_FIELD_DEF(int64_t, rowCount); -TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount); -TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount); -TILING_DATA_FIELD_DEF(int64_t, expertCount); -TILING_DATA_FIELD_DEF(int64_t, addBias); -TILING_DATA_FIELD_DEF(int64_t, k); -TILING_DATA_FIELD_DEF(int64_t, kGroup); -TILING_DATA_FIELD_DEF(int64_t, groupCount); -TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount); -TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); -TILING_DATA_FIELD_DEF(int64_t, groupSelectMode); -TILING_DATA_FIELD_DEF(int64_t, renorm); -TILING_DATA_FIELD_DEF(int64_t, normType); -TILING_DATA_FIELD_DEF(int64_t, outFlag); -TILING_DATA_FIELD_DEF(int64_t, vmsCount); -TILING_DATA_FIELD_DEF(float, routedScalingFactor); -TILING_DATA_FIELD_DEF(float, eps); -TILING_DATA_FIELD_DEF(int64_t, calTmpBufUbSize); -END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(MoeGatingTopK, MoeGatingTopKTilingData) - -BEGIN_TILING_DATA_DEF(MoeGatingTopKRegbaseTilingData) -TILING_DATA_FIELD_DEF(int64_t, needCoreNum); -TILING_DATA_FIELD_DEF(int64_t, rowCount); -TILING_DATA_FIELD_DEF(int64_t, perCoreRowCount); -TILING_DATA_FIELD_DEF(int64_t, lastCoreRowCount); -TILING_DATA_FIELD_DEF(int64_t, expertCount); -TILING_DATA_FIELD_DEF(int64_t, addBias); -TILING_DATA_FIELD_DEF(int64_t, k); -TILING_DATA_FIELD_DEF(int64_t, kGroup); -TILING_DATA_FIELD_DEF(int64_t, groupCount); -TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCount); -TILING_DATA_FIELD_DEF(int64_t, perGroupExpertCountAlign); -TILING_DATA_FIELD_DEF(int64_t, groupSelectMode); -TILING_DATA_FIELD_DEF(int64_t, renorm); -TILING_DATA_FIELD_DEF(int64_t, normType); -TILING_DATA_FIELD_DEF(int64_t, outFlag); -TILING_DATA_FIELD_DEF(int64_t, vmsCount); -TILING_DATA_FIELD_DEF(float, routedScalingFactor); -TILING_DATA_FIELD_DEF(float, eps); -TILING_DATA_FIELD_DEF_STRUCT(SoftMaxTiling, softmaxTilingData); -END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(MoeGatingTopK_10000, MoeGatingTopKRegbaseTilingData) -struct MoeGatingTopKCompileInfo {}; -} // namespace optiling -#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_GATING_TOP_K_H diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp deleted file mode 100644 index 84d61c6a..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_arch35.cpp +++ /dev/null @@ -1,521 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/* ! - * \file moe_gating_top_k_tiling_arch35.cpp - * \brief - */ - -#include "error_log.h" -#include "moe_gating_top_k_tiling.h" -#include "register/op_def_registry.h" -#include "platform/platform_info.h" -#include "../tiling_base/tiling_base.h" -#include "../tiling_base/tiling_templates_registry.h" - -#ifndef CEIL_ALIGN -#define CEIL_ALIGN(val, align) ((((val) + (align) - 1) / (align)) * (align)) -#endif -#ifndef CEIL_DIV -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) -#endif -namespace optiling { -const static uint64_t MOE_GATING_TOP_K_REGBASE_TILING_KEY = 10000; - -const static int64_t GROUP_SELECT_MODE_MAX = 0; -const static int64_t GROUP_SELECT_MODE_SUM = 1; -const static int64_t RENORM_NO = 0; -const static int64_t RENORM_L1 = 1; -const static int64_t NORM_TYPE_SOFTMAX = 0; -const static int64_t NORM_TYPE_SIGMOID = 1; -const static int64_t OUT_FLAG_FALSE = 0; -const static int64_t OUT_FLAG_TRUE = 1; -const static size_t X_INPUT_DIMS = 2; -const static size_t BIAS_INPUT_DIMS = 1; -const static size_t Y_OUTPUT_DIMS = 2; -const static size_t EXPERT_IDX_OUTPUY_DIMS = 2; -const static size_t OUT_OUTPUT_DIMS = 2; -const static int64_t MAX_EXPERT_COUNT = 2048; - -const static int64_t X_INPUT_INDEX = 0; -const static int64_t BIAS_INPUT_INDEX = 1; -const static int64_t Y_OUTPUT_INDEX = 0; -const static int64_t EXPERT_IDX_OUTPUT_INDEX = 1; -const static int64_t OUT_OUTPUT_INDEX = 2; -const static int64_t K_ATTR_INDEX = 0; -const static int64_t K_GROUP_ATTR_INDEX = 1; -const static int64_t GROUP_COUNT_ATTR_INDEX = 2; -const static int64_t GROUP_SELECT_MODE_ATTR_INDEX = 3; -const static int64_t RENORM_ATTR_INDEX = 4; -const static int64_t MRGSORT_SIZE = 4; -const static int64_t NORM_TYPE_ATTR_INDEX = 5; -const static int64_t OUT_FLAG_ATTR_INDEX = 6; -const static int64_t ROUTED_SCALING_FACTOR_ATTR_INDEX = 7; -const static int64_t EPS_ATTR_INDEX = 8; -const static int64_t DEFAULT_WORKSPACE_SIZE = static_cast(16 * 1024 * 1024); // 预留16M空间 - - -class MoeGatingTopKTilingRegbase : public Ops::Transformer::OpTiling::TilingBaseClass { -public: - explicit MoeGatingTopKTilingRegbase(gert::TilingContext *context) : Ops::Transformer::OpTiling::TilingBaseClass(context) - { - Reset(); - } - ~MoeGatingTopKTilingRegbase() override = default; - - void Reset(gert::TilingContext *context) override - { - TilingBaseClass::Reset(context); - Reset(); - } - -protected: - bool IsCapable() override - { - if (socVersion != platform_ascendc::SocVersion::ASCEND910_95) { - return false; - } - return true; - } - // 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小 - ge::graphStatus GetPlatformInfo() override; - // 2、获取INPUT/OUTPUT/ATTR信息 - ge::graphStatus GetShapeAttrsInfo() override; - // 3、计算数据切分TilingData - ge::graphStatus DoOpTiling() override; - // 4、计算高阶API的TilingData - ge::graphStatus DoLibApiTiling() override; - // 5、计算TilingKey - uint64_t GetTilingKey() const override; - // 6、计算Workspace 大小 - ge::graphStatus GetWorkspaceSize() override; - // 7、保存Tiling数据 - ge::graphStatus PostTiling() override; - void Reset(); - -private: - ge::graphStatus CheckInputShape(); - ge::graphStatus CheckAttr(); - ge::graphStatus CheckOutShape(); - void CalTmpBufUbSize(); - void SplitRows(); - void Tiling4GatherOutComputeSplitK(); - - const gert::Shape *xShape_ = nullptr; - const gert::Shape *biasShape_ = nullptr; - const gert::Shape *yShape_ = nullptr; - const gert::Shape *expertIdxShape_ = nullptr; - const gert::Shape *outShape_ = nullptr; - - int64_t rows_; - int64_t expertCount_; - int64_t addBias_ = 0; - - int64_t k_; - int64_t kGroup_ = 1; - int64_t groupCount_ = 1; - int64_t groupSelectMode_ = GROUP_SELECT_MODE_MAX; - int64_t renorm_ = RENORM_NO; - int64_t normType_ = NORM_TYPE_SOFTMAX; - int64_t outFlag_ = OUT_FLAG_FALSE; - float routedScalingFactor_ = 1.0; - float eps_ = 1e-20f; - - int64_t inputDtypeSize_; - const char *opName_ = ""; - MoeGatingTopKRegbaseTilingData moeGatingTopKTilingData_; - platform_ascendc::SocVersion socVersion; -}; - -ge::graphStatus MoeGatingTopKTilingRegbase::CheckInputShape() -{ - size_t xDimNum = xShape_->GetDimNum(); - OP_CHECK_IF(xDimNum != X_INPUT_DIMS, - OP_LOGE(context_, "The dim number of x is: %zu, but should be %zu.", xDimNum, X_INPUT_DIMS), - return ge::GRAPH_FAILED); - - // 通过输入获取rows 和 expertCount - rows_ = xShape_->GetDim(0); - expertCount_ = xShape_->GetDim(1); - moeGatingTopKTilingData_.set_rowCount(rows_); - moeGatingTopKTilingData_.set_expertCount(expertCount_); - OP_CHECK_IF( - expertCount_ > MAX_EXPERT_COUNT, - OP_LOGE(context_, "expert count is: %ld, but should not greater than %ld.", expertCount_, MAX_EXPERT_COUNT), - return ge::GRAPH_FAILED); - - if (biasShape_ != nullptr) { - addBias_ = 1; - size_t biasDimNum = biasShape_->GetDimNum(); - OP_CHECK_IF(biasDimNum != BIAS_INPUT_DIMS, - OP_LOGE(context_, "The number of bias dim is: %zu, but should be %zu.", biasDimNum, BIAS_INPUT_DIMS), - return ge::GRAPH_FAILED); - OP_CHECK_IF(biasShape_->GetDim(0) != expertCount_, - OP_LOGE(context_, "The first dim of bias is: %ld, but should be expert num: %ld.", - biasShape_->GetDim(0), expertCount_), - return ge::GRAPH_FAILED); - } - moeGatingTopKTilingData_.set_addBias(addBias_); - - OP_CHECK_IF(k_ > expertCount_, - OP_LOGE(context_, "k is: %ld, expert num is: %ld, k cannot be greater than expert num.", k_, expertCount_), - return ge::GRAPH_FAILED); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::CheckAttr() -{ - OP_CHECK_IF(k_ <= 0, OP_LOGE(context_, "k is: %ld, but should be greater than 0.", k_), return ge::GRAPH_FAILED); - OP_CHECK_IF(kGroup_ <= 0, OP_LOGE(context_, "k_group is: %ld, but should be greater than 0.", kGroup_), - return ge::GRAPH_FAILED); - OP_CHECK_IF(groupCount_ <= 0, OP_LOGE(context_, "group_count is: %ld, but should be greater than 0.", groupCount_), - return ge::GRAPH_FAILED); - OP_CHECK_IF(expertCount_ % groupCount_ != 0, - OP_LOGE(context_, "expert num : %ld is not divisible by group_count: %ld", expertCount_, groupCount_), - return ge::GRAPH_FAILED); - OP_CHECK_IF(kGroup_ > groupCount_, - OP_LOGE(context_, "k_group is: %ld, but should not greater than group_count: %ld", kGroup_, groupCount_), - return ge::GRAPH_FAILED); - OP_CHECK_IF(groupCount_ == expertCount_ && kGroup_ < k_, - OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.", - kGroup_, k_), - return ge::GRAPH_FAILED); - - if (kGroup_ == groupCount_ || groupCount_ == expertCount_) { - kGroup_ = 1; - groupCount_ = 1; - } - moeGatingTopKTilingData_.set_kGroup(kGroup_); - moeGatingTopKTilingData_.set_groupCount(groupCount_); - int64_t groupExpertCount = expertCount_ / groupCount_; - int64_t groupExpertCountAlign = CEIL_ALIGN(groupExpertCount, 32L); - moeGatingTopKTilingData_.set_perGroupExpertCount(expertCount_ / groupCount_); - moeGatingTopKTilingData_.set_perGroupExpertCountAlign(groupExpertCountAlign); - - OP_CHECK_IF(groupCount_ * groupExpertCountAlign > MAX_EXPERT_COUNT, - OP_LOGE(context_, "group count * group expert count align is: %ld, but should not greater than %ld.", - groupCount_ * groupExpertCountAlign, MAX_EXPERT_COUNT), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(kGroup_ * groupExpertCount < k_, - OP_LOGE(context_, "k_group * group expert count is: %ld, but it must be greater than or equal to k: %ld.", - kGroup_ * groupExpertCount, k_), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(groupExpertCount < 1, - OP_LOGE(context_, "per group expert count is: %ld, but should be greater than 0.", groupExpertCount), - return ge::GRAPH_FAILED); - OP_CHECK_IF( - groupSelectMode_ != GROUP_SELECT_MODE_SUM && groupSelectMode_ != GROUP_SELECT_MODE_MAX, - OP_LOGE(context_, "group select mode is: %ld, but currently only support %ld and %ld.", groupSelectMode_, - GROUP_SELECT_MODE_SUM, GROUP_SELECT_MODE_MAX), - return ge::GRAPH_FAILED); - OP_CHECK_IF(groupSelectMode_ == GROUP_SELECT_MODE_SUM && groupExpertCount < 2, - OP_LOGE(context_, - "group expert count is: %ld, if group select mode is: %ld, group expert count should be greater than 1.", - groupExpertCount, groupSelectMode_), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(renorm_ != RENORM_NO, - OP_LOGE(context_, "renorm is: %ld, but currently only support %ld.", renorm_, RENORM_NO), - return ge::GRAPH_FAILED); - - OP_CHECK_IF(normType_ != NORM_TYPE_SOFTMAX && normType_ != NORM_TYPE_SIGMOID, - OP_LOGE(context_, "norm type is: %ld, but currently only support %ld and %ld.", normType_, - NORM_TYPE_SOFTMAX, NORM_TYPE_SIGMOID), - return ge::GRAPH_FAILED); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::GetShapeAttrsInfo() -{ - opName_ = context_->GetNodeName(); - // 获取输入shape信息 - auto xShapePtr = context_->GetInputShape(X_INPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr); - xShape_ = &xShapePtr->GetStorageShape(); - auto biasShapePtr = context_->GetOptionalInputShape(BIAS_INPUT_INDEX); - biasShape_ = biasShapePtr == nullptr ? nullptr : &biasShapePtr->GetStorageShape(); - - // 获取输出shape - auto yShapePtr = context_->GetOutputShape(Y_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, yShapePtr); - yShape_ = &yShapePtr->GetStorageShape(); - auto expertIdxPtr = context_->GetOutputShape(EXPERT_IDX_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxPtr); - expertIdxShape_ = &expertIdxPtr->GetStorageShape(); - auto outPtr = context_->GetOutputShape(OUT_OUTPUT_INDEX); - if (outPtr != nullptr) { - outShape_ = &outPtr->GetStorageShape(); - } - - auto x = context_->GetInputDesc(X_INPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, x); - auto xDtype = x->GetDataType(); - OP_CHECK_IF( - (xDtype != ge::DataType::DT_FLOAT && xDtype != ge::DataType::DT_FLOAT16 && xDtype != ge::DataType::DT_BF16), - OP_LOGE(context_, "x dtype %s error, only supports float32, half, bf16. please check.", - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - - if (biasShapePtr != nullptr) { - auto biasDtype = context_->GetOptionalInputDesc(BIAS_INPUT_INDEX)->GetDataType(); - OP_CHECK_IF((biasDtype != xDtype), - OP_LOGE(context_, "bias dtype %s not equal x dtype %s, please check.", - ge::TypeUtils::DataTypeToSerialString(biasDtype).c_str(), - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - } - - auto yDesc = context_->GetOutputDesc(Y_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, yDesc); - auto yDtype = yDesc->GetDataType(); - OP_CHECK_IF((yDtype != xDtype), - OP_LOGE(context_, "y out dtype %s must be the same with x dtype %s.", - ge::TypeUtils::DataTypeToSerialString(yDtype).c_str(), - ge::TypeUtils::DataTypeToSerialString(xDtype).c_str()), - return ge::GRAPH_FAILED); - - auto expertIdDesc = context_->GetOutputDesc(EXPERT_IDX_OUTPUT_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdDesc); - auto expertIdDtype = expertIdDesc->GetDataType(); - OP_CHECK_IF((expertIdDtype != ge::DataType::DT_INT32), - OP_LOGE(context_, "expertId out dtype %s error, only supports int32. please check.", - ge::TypeUtils::DataTypeToSerialString(expertIdDtype).c_str()), - return ge::GRAPH_FAILED); - - // 获取属性 - auto attrs = context_->GetAttrs(); - OP_CHECK_NULL_WITH_CONTEXT(context_, attrs); - - const int64_t *kPtr = attrs->GetAttrPointer(K_ATTR_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context_, kPtr); - k_ = *kPtr; - moeGatingTopKTilingData_.set_k(k_); - OP_LOGI(context_, "Attr k is: %ld ", k_); - - const int64_t *kGroupPtr = attrs->GetAttrPointer(K_GROUP_ATTR_INDEX); - if (kGroupPtr != nullptr) { - kGroup_ = *kGroupPtr; - } - OP_LOGI(context_, "Attr k_group is: %ld ", kGroup_); - - const int64_t *groupCountPtr = attrs->GetAttrPointer(GROUP_COUNT_ATTR_INDEX); - if (groupCountPtr != nullptr) { - groupCount_ = *groupCountPtr; - } - OP_LOGI(context_, "Attr group_count is: %ld ", groupCount_); - - const int64_t *groupSelectModePtr = attrs->GetAttrPointer(GROUP_SELECT_MODE_ATTR_INDEX); - if (groupSelectModePtr != nullptr) { - groupSelectMode_ = *groupSelectModePtr; - } - moeGatingTopKTilingData_.set_groupSelectMode(groupSelectMode_); - OP_LOGI(context_, "Attr group_select_mode is: %ld ", groupSelectMode_); - - const int64_t *renormPtr = attrs->GetAttrPointer(RENORM_ATTR_INDEX); - if (renormPtr != nullptr) { - renorm_ = *renormPtr; - } - moeGatingTopKTilingData_.set_renorm(renorm_); - OP_LOGI(context_, "Attr renorm is: %ld ", renorm_); - - const int64_t *normTypePtr = attrs->GetAttrPointer(NORM_TYPE_ATTR_INDEX); - if (normTypePtr != nullptr) { - normType_ = *normTypePtr; - } - moeGatingTopKTilingData_.set_normType(normType_); - OP_LOGI(context_, "Attr norm_type is: %ld ", normType_); - - const bool *outFlagPtr = attrs->GetAttrPointer(OUT_FLAG_ATTR_INDEX); - if (outFlagPtr != nullptr) { - outFlag_ = (*outFlagPtr) ? 1 : 0; - } - moeGatingTopKTilingData_.set_outFlag(outFlag_); - OP_LOGI(context_, "Attr out_flag is: %ld ", outFlag_); - - const float *routedScalingFactorPtr = attrs->GetAttrPointer(ROUTED_SCALING_FACTOR_ATTR_INDEX); - if (routedScalingFactorPtr != nullptr) { - routedScalingFactor_ = *routedScalingFactorPtr; - } - moeGatingTopKTilingData_.set_routedScalingFactor(routedScalingFactor_); - OP_LOGI(context_, "Attr routed_scaling_factor is: %f ", routedScalingFactor_); - - const float *epsPtr = attrs->GetAttrPointer(EPS_ATTR_INDEX); - if (epsPtr != nullptr) { - eps_ = *epsPtr; - } - moeGatingTopKTilingData_.set_eps(eps_); - OP_LOGI(context_, "Attr eps is: %f ", eps_); - - auto outDesc = context_->GetOutputDesc(OUT_OUTPUT_INDEX); - if (outFlag_ && outDesc != nullptr) { - auto outDtype = outDesc->GetDataType(); - OP_CHECK_IF((outDtype != ge::DataType::DT_FLOAT), - OP_LOGE(context_, "norm out dtype %s error, only supports float32. please check.", - ge::TypeUtils::DataTypeToSerialString(outDtype).c_str()), - return ge::GRAPH_FAILED); - } - - inputDtypeSize_ = static_cast(ge::GetSizeByDataType(context_->GetInputDesc(0)->GetDataType())); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::GetPlatformInfo() -{ - auto platformInfo = context_->GetPlatformInfo(); - OP_CHECK_IF(platformInfo == nullptr, OP_LOGE(context_, "fail to get platform info"), return ge::GRAPH_FAILED); - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); - aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv(); - socVersion = ascendcPlatform.GetSocVersion(); - uint64_t ubSizePlatForm; - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm); - aicoreParams_.ubSize = ubSizePlatForm; - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::CheckOutShape() -{ - OP_CHECK_IF((yShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "y out shape num %zu and x shape num %zu not equal, please check.", yShape_->GetDimNum(), - xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "expertId out shape num %zu and x shape num %zu not equal, please check.", - expertIdxShape_->GetDimNum(), xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - if (outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDimNum() != xShape_->GetDimNum()), - OP_LOGE(context_, "norm out shape num %zu and x shape num %zu not equal, please check.", - outShape_->GetDimNum(), xShape_->GetDimNum()), - return ge::GRAPH_FAILED); - } - - OP_CHECK_IF((yShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "y out dim[0] %ld not euqal x dim[0] %ld, please check.", yShape_->GetDim(0), - xShape_->GetDim(0)), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "expertId out dim[0] %ld not euqal x dim[0] %ld, please check.", - expertIdxShape_->GetDim(0), xShape_->GetDim(0)), - return ge::GRAPH_FAILED); - if (outFlag_ && outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDim(0) != xShape_->GetDim(0)), - OP_LOGE(context_, "norm out dim[0] %ld and x dim[0] %ld not equal, please check.", - outShape_->GetDim(0), outShape_->GetDim(0)), - return ge::GRAPH_FAILED); - } - - OP_CHECK_IF((yShape_->GetDim(1) != k_), - OP_LOGE(context_, "y dim[1] %ld not euqal k %ld, please check.", yShape_->GetDim(1), k_), - return ge::GRAPH_FAILED); - OP_CHECK_IF((expertIdxShape_->GetDim(1) != k_), - OP_LOGE(context_, "expertId dim[1] %ld not euqal k %ld, please check.", expertIdxShape_->GetDim(1), k_), - return ge::GRAPH_FAILED); - if (outFlag_ && outShape_ != nullptr) { - OP_CHECK_IF((outShape_->GetDim(1) != xShape_->GetDim(1)), - OP_LOGE(context_, "normOut dim[1] %ld and x dim[1] %ld not equal, please check.", outShape_->GetDim(1), - xShape_->GetDim(1)), - return ge::GRAPH_FAILED); - } - return ge::GRAPH_SUCCESS; -} - -void MoeGatingTopKTilingRegbase::CalTmpBufUbSize() { - std::vector shape_vec = {groupCount_ * moeGatingTopKTilingData_.get_perGroupExpertCountAlign()}; - ge::Shape softmaxShape(shape_vec); - - uint32_t softmaxTmpSize = AscendC::GetSoftMaxMaxTmpSize(softmaxShape, sizeof(float), true); - AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpSize, moeGatingTopKTilingData_.softmaxTilingData); -} - -void MoeGatingTopKTilingRegbase::SplitRows() -{ - int64_t perCoreRows = CEIL_DIV(rows_, static_cast(aicoreParams_.blockDim)); - int64_t needCoreNum = CEIL_DIV(rows_, perCoreRows); - if (perCoreRows == 0) { - OP_LOGE(context_, "perCoreRows can't be 0."); - return; - } - int64_t lastCoreRows = rows_ % perCoreRows == 0 ? perCoreRows : rows_ % perCoreRows; - moeGatingTopKTilingData_.set_needCoreNum(needCoreNum); - moeGatingTopKTilingData_.set_perCoreRowCount(perCoreRows); - moeGatingTopKTilingData_.set_lastCoreRowCount(lastCoreRows); - - int64_t vmsCount = 0; - if (kGroup_ > MRGSORT_SIZE) { - int64_t index = MRGSORT_SIZE; - while (index < kGroup_) { - index = index * MRGSORT_SIZE; - vmsCount++; - } - } - moeGatingTopKTilingData_.set_vmsCount(vmsCount); -} - -ge::graphStatus MoeGatingTopKTilingRegbase::DoOpTiling() -{ - auto ret = CheckInputShape(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - ret = CheckAttr(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - ret = CheckOutShape(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - CalTmpBufUbSize(); - SplitRows(); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::DoLibApiTiling() -{ - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::GetWorkspaceSize() -{ - // 计算workspace大小 - workspaceSize_ = DEFAULT_WORKSPACE_SIZE; - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus MoeGatingTopKTilingRegbase::PostTiling() -{ - context_->SetBlockDim(moeGatingTopKTilingData_.get_needCoreNum()); - size_t *currentWorkspace = context_->GetWorkspaceSizes(1); - currentWorkspace[0] = workspaceSize_; - moeGatingTopKTilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), - context_->GetRawTilingData()->GetCapacity()); - context_->GetRawTilingData()->SetDataSize(moeGatingTopKTilingData_.GetDataSize()); - return ge::GRAPH_SUCCESS; -} - -uint64_t MoeGatingTopKTilingRegbase::GetTilingKey() const -{ - return MOE_GATING_TOP_K_REGBASE_TILING_KEY; -} - -void MoeGatingTopKTilingRegbase::Reset() -{ - opName_ = nullptr; - return; -} - -REGISTER_OPS_TILING_TEMPLATE(MoeGatingTopK, MoeGatingTopKTilingRegbase, 1000); -} // namespace optiling diff --git a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp b/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp deleted file mode 100644 index db9ce60d..00000000 --- a/csrc/moe_gating_top_k/op_host/moe_gating_top_k_tiling_base.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/* ! - * \file moe_gating_top_k_tiling_base.cpp - * \brief - */ -#include "moe_gating_top_k_tiling.h" -#include "register/op_def_registry.h" -#include "../tiling_base/tiling_base.h" -#include "../tiling_base/tiling_templates_registry.h" -#include "error_log.h" -#include "kernel_tiling/kernel_tiling.h" - -namespace optiling { -static ge::graphStatus TilingForMoeGatingTopK(gert::TilingContext *context) -{ - return Ops::Transformer::OpTiling::TilingRegistry::GetInstance().DoTilingImpl(context); -} - -static ge::graphStatus TilingPrepareForMoeGatingTopK(gert::TilingParseContext *context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -IMPL_OP_OPTILING(MoeGatingTopK) - .Tiling(TilingForMoeGatingTopK) - .TilingParse(TilingPrepareForMoeGatingTopK); - -} // namespace optiling \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/common.h b/csrc/moe_gating_top_k/op_kernel/common.h deleted file mode 100644 index 0847ee2b..00000000 --- a/csrc/moe_gating_top_k/op_kernel/common.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file common.h - * \brief - */ -#ifndef MOE_GATING_TOP_K_COMMON_H -#define MOE_GATING_TOP_K_COMMON_H - -#include "kernel_operator.h" - -namespace MoeGatingTopK { -using namespace AscendC; -const float MIN_FP32 = *(float *)(&F32_NEG_INF); -constexpr int32_t FLOAT32_NEG_INF = 0xFF800000; // -inf -2139095040 -constexpr int64_t ONE_REPEAT_SORT_NUM = 32; -constexpr int64_t BLOCK_BYTES = 32; -constexpr int64_t REPEAT_BYTES = 256; -constexpr int64_t REPEAT_BLOCKS = 8; - -constexpr int32_t CONSTANT_TWO = 2; -constexpr int32_t CONSTANT_THREE = 3; -constexpr int32_t CONSTANT_FOUR = 4; -constexpr int32_t CONSTANT_EIGHT = 8; - -constexpr int64_t MERGE_LIST_TWO = 2; -constexpr int64_t MERGE_LIST_THREE = 3; -constexpr int64_t MERGE_LIST_FOUR = 4; - -constexpr int64_t MERGE_LIST_IDX_TWO = 2; -constexpr int64_t MERGE_LIST_IDX_THREE = 3; - -constexpr int64_t NORM_TYPE_SOFTMAX = 0; -constexpr int64_t NORM_TYPE_SIGMOID = 1; - -__aicore__ inline int64_t Ceil(int64_t a, int64_t b) -{ - if (b == 0) { - return 0; - } - return (a + b - 1) / b; -} - -__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) -{ - if (bytes == 0) { - return 0; - } - return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes; -} - -__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) -{ - return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES; -} - -template -__aicore__ inline T Min(T a, T b) -{ - return a > b ? b : a; -} - -template -__aicore__ inline T Max(T a, T b) -{ - return a < b ? b : a; -} - -template -__aicore__ inline T1 CeilDiv(T1 x, T2 y) -{ - if (y != 0 && x != 0) { - const T1 quotient = x / y; - return (x % y != 0 && ((x ^ y) >= 0)) ? (quotient + 1) : quotient; - } - - return x; -} - -} // namespace MoeGatingTopK -#endif // MOE_GATING_TOP_K_COMMON_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/error_log.h b/csrc/moe_gating_top_k/op_kernel/error_log.h deleted file mode 100644 index 8b64eced..00000000 --- a/csrc/moe_gating_top_k/op_kernel/error_log.h +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ -#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ - -#include -#include "toolchain/slog.h" - -#define OP_LOGI(opname, ...) -#define OP_LOGW(opname, ...) \ - do { \ - printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ - do { \ - printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE(opname, ...) \ - do { \ - printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGD(opname, ...) - -namespace optiling { - -#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ - do { \ - OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ - } while (0) - -// 修改 OP_TILING_CHECK 宏,确保正确处理表达式 -#define OP_CHECK_IF(cond, log_func, expr) \ - do { \ - if (cond) { \ - log_func; \ - expr; \ - } \ - } while (0) - - - -#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ - do { \ - if ((ptr) == nullptr) { \ - OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -} // namespace optiling - -#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp deleted file mode 100644 index 4390fee4..00000000 --- a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k.cpp - * \brief - */ - -#include "moe_gating_top_k_e_k_fullload.h" -#include "moe_gating_top_k_without_group.h" -#include "moe_gating_top_k_generalized.h" -#include "error_log.h" - -#define TILING_KEY_PER_GROUP_COUNT_32 0 -#define TILING_KEY_WITHOUT_GROUP 1 -#define TILING_KEY_GENERALIZED 2 - -using namespace AscendC; -using namespace MoeGatingTopK; -extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, - GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling) -{ - - KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); - if (g_coreType == AIC) { - return; - } - - - GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKTilingData, tilingData, tiling); - if (workspace == nullptr) { - return; - } - - GM_ADDR userWS = GetUserWorkspace(workspace); - if (userWS == nullptr) { - return; - } - - const MoeGatingTopKTilingData *__restrict t = &tilingData; - TPipe tPipe; - if (TILING_KEY_IS(TILING_KEY_PER_GROUP_COUNT_32)) { - MoeGatingTopKEKFullload op; - op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); - op.Process(); - } else if (TILING_KEY_IS(TILING_KEY_WITHOUT_GROUP)) { - MoeGatingTopKWithoutGroup op; - op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); - op.Process(); - } else if (TILING_KEY_IS(TILING_KEY_GENERALIZED)) { - MoeGatingTopKGenerlized op; - op.Init(x, bias, y, expertIdx, out, userWS, t, &tPipe); - op.Process(); - } - -} diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp deleted file mode 100644 index 5abaf2ba..00000000 --- a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_apt.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_apt.cpp - * \brief - */ - -#include "arch35/moe_gating_top_k_regbase.h" -using namespace AscendC; -using namespace MoeGatingTopK; - -#define TILING_KEY_REGBASE 10000 - -extern "C" __global__ __aicore__ void moe_gating_top_k(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, - GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling) -{ - if (g_coreType == AIC) { - return; - } - - if (workspace == nullptr) { - return; - } - - GM_ADDR userWS = GetUserWorkspace(workspace); - if (userWS == nullptr) { - return; - } - - GET_TILING_DATA_WITH_STRUCT(MoeGatingTopKRegbaseTilingData, tiling_data_in, tiling); - const MoeGatingTopKRegbaseTilingData *__restrict tilingData = &tiling_data_in; - TPipe tPipe; - if (TILING_KEY_IS(TILING_KEY_REGBASE)) { - MoeGatingTopKRegbase op; - op.Init(x, bias, y, expertIdx, out, userWS, tilingData, &tPipe); - op.Process(); - } -} diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h deleted file mode 100644 index 6891cc0e..00000000 --- a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_e_k_fullload.h +++ /dev/null @@ -1,404 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_e_k_fullload.h - * \brief - */ -#ifndef MOE_GATING_TOP_K_E_K_FULLLOAD_H -#define MOE_GATING_TOP_K_E_K_FULLLOAD_H -#include "kernel_operator.h" -#include "common.h" -namespace MoeGatingTopK { -using namespace AscendC; - -template -class MoeGatingTopKEKFullload { -public: - __aicore__ inline MoeGatingTopKEKFullload(){}; - __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); - __aicore__ inline void Process(); - -private: - __aicore__ inline void CopyInBias(); - __aicore__ inline void CopyInX(int64_t progress); - __aicore__ inline void ComputeX(); - __aicore__ inline void SortInGroup(); - __aicore__ inline void SelectTopKGroupIndex(); - __aicore__ inline void SelectTopKExpertIdx(); - __aicore__ inline void SelectTopKExpertScore(); - __aicore__ inline void CopyOut(int64_t progress); - -private: - TPipe *pipe_; - TQue xInQueue_; - TBuf biasInQueue_; - TQue yOutQueue_; - TQue expertIdxOutQueue_; - TQue outOutQueue_; - - TQue xBiasQueue_; - TQue xSigmoidQueue_; - TQue sigmoidTmpQueue_; - TQue sortedInGroupQueue_; - TQue sortedGroupQueue_; - TBuf calcTmpBuffer_; - - GlobalTensor xGm_; - GlobalTensor biasGm_; - GlobalTensor yGm_; - GlobalTensor expertIdxGm_; - GlobalTensor outGm_; - - int64_t blockIdx_; - int64_t perCoreRowCount_; - int64_t curCoreRowCount_; - int64_t expertCount_; - bool addBias_; - int64_t k_; - int64_t kGroup_; - int64_t groupCount_; - int64_t groupSelectMode_; - int64_t renorm_; - int64_t normType_; - int64_t outFlag_; - float routedScalingFactor_; - float eps_; - - int64_t expertCountAlign_; - int64_t kAlign_; - int64_t perGroupExpertCount_; - - const MoeGatingTopKTilingData *tilingData_; -}; - -template -__aicore__ inline void MoeGatingTopKEKFullload::CopyInBias() -{ - LocalTensor biasTensor = biasInQueue_.Get(); - DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if constexpr (IsSameType::value) { - DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - } else { - DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, expertCount_); - } -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::CopyInX(int64_t row) -{ - LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); - DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if constexpr (IsSameType::value) { - DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); - } else { - DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], dataCopyParams, - dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, - expertCount_); - } - - xInQueue_.EnQue(xInLocalTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::ComputeX() -{ - LocalTensor xSigmoidTensor = xSigmoidQueue_.AllocTensor(); - LocalTensor xInLocalTensor = xInQueue_.DeQue(); - LocalTensor xBiasTensor = xBiasQueue_.AllocTensor(); - LocalTensor biasTensor = biasInQueue_.Get(); - LocalTensor sharedTmpBuffer = sigmoidTmpQueue_.AllocTensor(); // 临时空间可以复用 - Sigmoid(xSigmoidTensor, xInLocalTensor, sharedTmpBuffer, expertCount_); - PipeBarrier(); - if (addBias_) { - Add(xBiasTensor, xSigmoidTensor, biasTensor, expertCount_); - } else { - Adds(xBiasTensor, xSigmoidTensor, static_cast(0), expertCount_); - } - - xSigmoidQueue_.EnQue(xSigmoidTensor); - xBiasQueue_.EnQue(xBiasTensor); - xInQueue_.FreeTensor(xInLocalTensor); - sigmoidTmpQueue_.FreeTensor(sharedTmpBuffer); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::SortInGroup() -{ - LocalTensor xBiasTensor = xBiasQueue_.DeQue(); - LocalTensor sortedInGroupTensor = sortedInGroupQueue_.AllocTensor(); // 组内排序的结果, 后续归并需要 - LocalTensor indexTensor = calcTmpBuffer_.Get(); // 用于存储排序时的索引 - ArithProgression(indexTensor.ReinterpretCast(), 0, 1, expertCount_); // 生成组索引0 1 2 ...... - PipeBarrier(); - Sort32(sortedInGroupTensor, xBiasTensor, indexTensor, expertCount_ / ONE_REPEAT_SORT_NUM); // 组内排序 - sortedInGroupQueue_.EnQue(sortedInGroupTensor); - xBiasQueue_.FreeTensor(xBiasTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKGroupIndex() -{ - LocalTensor sortedInGroupTensor = sortedInGroupQueue_.DeQue(); - LocalTensor indexTensor = calcTmpBuffer_.Get(); - LocalTensor top2ValueInGroupTensor = sigmoidTmpQueue_.AllocTensor(); // 这个临时空间可以复用 - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - - indexTensor.SetValue(0, static_cast(5)); // b0101 - indexTensor.SetValue(1, static_cast(0)); - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = 8; - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = 8; - gatherMaskParams.src1RepeatStride = 0; - GatherMask(top2ValueInGroupTensor, sortedInGroupTensor, indexTensor, true, static_cast(64), - gatherMaskParams, rsvdCnt); - PipeBarrier(); - LocalTensor groupTop2SumTensor = top2ValueInGroupTensor; - PairReduceSum(groupTop2SumTensor, top2ValueInGroupTensor, 1, groupCount_ * 2, 1, 1, - 1); // 计算每个组内最大的两个数之和 - PipeBarrier(); - - LocalTensor groupIndexTensor = indexTensor; - ArithProgression(groupIndexTensor.ReinterpretCast(), 0, 1, groupCount_); // 生成组索引 - PipeBarrier(); - // 用最小值补到32个数 - int64_t duplicateNum = ONE_REPEAT_SORT_NUM - groupCount_; - if (duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX << groupCount_; - uint64_t mask[2] = {mask0, 0}; - Duplicate(groupTop2SumTensor, MIN_FP32, mask, 1, 1, 8); - PipeBarrier(); - } - // 排序,将kgroup选出来 - LocalTensor sortedGroupTensor = sortedGroupQueue_.AllocTensor(); - Sort32(sortedGroupTensor, groupTop2SumTensor, groupIndexTensor, 1); - - PipeBarrier(); - LocalTensor sortedGroupIndexTensor = indexTensor.ReinterpretCast(); - // 提取组序号 - uint8_t src1Pattern = 2; // 内置固定模式 - GatherMask(sortedGroupIndexTensor, sortedGroupTensor.template ReinterpretCast(), src1Pattern, false, - static_cast(0), {1, 1, 0, 0}, rsvdCnt); - - // 需要将组排序(这里是降序,所以下mrgsor的时候反着取,3、2、1、0) - Cast(sortedGroupTensor, sortedGroupIndexTensor, RoundMode::CAST_ROUND, kGroup_); - PipeBarrier(); - duplicateNum = ONE_REPEAT_SORT_NUM - kGroup_; - if (duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX << kGroup_; - uint64_t mask[2] = {mask0, 0}; - Duplicate(sortedGroupTensor, MIN_FP32, mask, 1, 1, 8); - PipeBarrier(); - } - Sort32(top2ValueInGroupTensor, sortedGroupTensor, sortedGroupIndexTensor.template ReinterpretCast(), 1); - PipeBarrier(); - src1Pattern = 1; - GatherMask(sortedGroupTensor, top2ValueInGroupTensor, src1Pattern, false, static_cast(0), {1, 1, 0, 0}, - rsvdCnt); - PipeBarrier(); - Cast(sortedGroupIndexTensor, sortedGroupTensor, RoundMode::CAST_ROUND, kGroup_); - - sortedGroupQueue_.FreeTensor(sortedGroupTensor); - sortedInGroupQueue_.EnQue(sortedInGroupTensor); - sigmoidTmpQueue_.FreeTensor(top2ValueInGroupTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKExpertIdx() -{ - LocalTensor expertIdxTensor = expertIdxOutQueue_.AllocTensor(); - LocalTensor topKGroupIndexTensor = calcTmpBuffer_.Get(); - LocalTensor sortedInGroupTensor = sortedInGroupQueue_.DeQue(); - LocalTensor sortedExpertTensor = xInQueue_.AllocTensor(); - AscendC::MrgSort4Info params; - params.elementLengths[0] = k_; - params.elementLengths[1] = k_; - params.elementLengths[2] = k_; - params.elementLengths[3] = k_; - params.ifExhaustedSuspension = true; - params.validBit = 0b1111; - params.repeatTimes = 1; - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - int64_t listOffset1 = topKGroupIndexTensor.GetValue(3) * perGroupExpertCount_ * 2; - int64_t listOffset2 = topKGroupIndexTensor.GetValue(2) * perGroupExpertCount_ * 2; - int64_t listOffset3 = topKGroupIndexTensor.GetValue(1) * perGroupExpertCount_ * 2; - int64_t listOffset4 = topKGroupIndexTensor.GetValue(0) * perGroupExpertCount_ * 2; - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - AscendC::MrgSortSrcList srcList; - srcList.src1 = sortedInGroupTensor[listOffset1]; - srcList.src2 = sortedInGroupTensor[listOffset2]; - srcList.src3 = sortedInGroupTensor[listOffset3]; - srcList.src4 = sortedInGroupTensor[listOffset4]; - MrgSort(sortedExpertTensor, srcList, params); - PipeBarrier(); - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - uint8_t src1Pattern = 2; // 内置固定模式 - GatherMask(expertIdxTensor, sortedExpertTensor.template ReinterpretCast(), src1Pattern, false, - static_cast(0), {1, 1, 0, 0}, rsvdCnt); - xInQueue_.FreeTensor(sortedExpertTensor); - expertIdxOutQueue_.EnQue(expertIdxTensor); - sortedInGroupQueue_.FreeTensor(sortedInGroupTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::SelectTopKExpertScore() -{ - LocalTensor expertIdxTensor = expertIdxOutQueue_.DeQue(); - LocalTensor expertByteIdxTensor = calcTmpBuffer_.Get(); - LocalTensor xSigmoidTensor = xSigmoidQueue_.DeQue(); - LocalTensor yTensor = yOutQueue_.AllocTensor(); - LocalTensor yOutTensor; - if constexpr (!IsSameType::value) { - yOutTensor = yTensor.template ReinterpretCast()[kAlign_]; - } else { - yOutTensor = yTensor; - } - Muls(expertByteIdxTensor, expertIdxTensor, static_cast(sizeof(float)), k_); - PipeBarrier(); - Gather(yOutTensor, xSigmoidTensor, expertByteIdxTensor.template ReinterpretCast(), - static_cast(0), k_); - - LocalTensor calTensor = calcTmpBuffer_.Get(); - PipeBarrier(); - ReduceSum(calTensor, yOutTensor, xSigmoidTensor, k_); - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float sumValue = calTensor.GetValue(0) + eps_; - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Duplicate(calTensor, sumValue, k_); - PipeBarrier(); - Div(yOutTensor, yOutTensor, calTensor, k_); - PipeBarrier(); - Muls(yOutTensor, yOutTensor, routedScalingFactor_, k_); - - if constexpr (!IsSameType::value) { - PipeBarrier(); - Cast(yTensor, yOutTensor, RoundMode::CAST_RINT, k_); - } - - xSigmoidQueue_.EnQue(xSigmoidTensor); - expertIdxOutQueue_.EnQue(expertIdxTensor); - yOutQueue_.EnQue(yTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::CopyOut(int64_t row) -{ - LocalTensor yOutTensor = yOutQueue_.DeQue(); - LocalTensor expertIdxTensor = expertIdxOutQueue_.DeQue(); - LocalTensor xSigmoidTensor = xSigmoidQueue_.DeQue(); - DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; - DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); - dataCopyParams.blockLen = k_ * sizeof(int32_t); - DataCopyPad(expertIdxGm_[row * k_], expertIdxTensor, dataCopyParams); - xSigmoidQueue_.FreeTensor(xSigmoidTensor); - expertIdxOutQueue_.FreeTensor(expertIdxTensor); - yOutQueue_.FreeTensor(yOutTensor); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, - GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) -{ - tilingData_ = tilingData; - pipe_ = tPipe; - blockIdx_ = GetBlockIdx(); - perCoreRowCount_ = tilingData_->perCoreRowCount; - if (blockIdx_ == GetBlockNum() - 1) { - curCoreRowCount_ = tilingData_->lastCoreRowCount; - } else { - curCoreRowCount_ = tilingData_->perCoreRowCount; - } - expertCount_ = tilingData_->expertCount; - addBias_ = tilingData_->addBias == 1; - k_ = tilingData_->k; - kGroup_ = tilingData_->kGroup; - groupCount_ = tilingData_->groupCount; - perGroupExpertCount_ = tilingData_->perGroupExpertCount; - routedScalingFactor_ = tilingData_->routedScalingFactor; - eps_ = tilingData_->eps; - - expertCountAlign_ = Align(expertCount_, sizeof(float)); - kAlign_ = Align(expertCount_, sizeof(float)); - - // init input gm buf - xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); - - // init output gm buf - yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); - expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); - outGm_.SetGlobalBuffer((__gm__ T *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - - // init que - pipe_->InitBuffer(xInQueue_, 2, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(biasInQueue_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - - pipe_->InitBuffer(xSigmoidQueue_, 1, AlignBytes(expertCount_, sizeof(float))); - pipe_->InitBuffer(xBiasQueue_, 2, AlignBytes(expertCount_, sizeof(float))); - - pipe_->InitBuffer(yOutQueue_, 2, kAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(expertIdxOutQueue_, 2, AlignBytes(k_, sizeof(int32_t))); - pipe_->InitBuffer(outOutQueue_, 2, AlignBytes(expertCount_, sizeof(float))); - - pipe_->InitBuffer(sigmoidTmpQueue_, 2, AlignBytes(expertCount_, sizeof(float))); - pipe_->InitBuffer(sortedInGroupQueue_, 2, AlignBytes(expertCount_, sizeof(float)) * 2); - pipe_->InitBuffer(sortedGroupQueue_, 2, - (groupCount_ + ONE_REPEAT_SORT_NUM - 1) / ONE_REPEAT_SORT_NUM * ONE_REPEAT_SORT_NUM * - sizeof(float) * 2); - - pipe_->InitBuffer(calcTmpBuffer_, tilingData_->calTmpBufUbSize); -} - -template -__aicore__ inline void MoeGatingTopKEKFullload::Process() -{ - CopyInBias(); - for (int64_t row = 0; row < curCoreRowCount_; row++) { - CopyInX(row); - ComputeX(); - SortInGroup(); - SelectTopKGroupIndex(); - SelectTopKExpertIdx(); - SelectTopKExpertScore(); - CopyOut(row); - } -} -} // namespace MoeGatingTopK -#endif // MOE_GATING_TOP_K_E_K_FULLLOAD_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h deleted file mode 100644 index 8f37a0fa..00000000 --- a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_generalized.h +++ /dev/null @@ -1,669 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_generalized.h - * \brief - */ -#ifndef MOE_GATING_TOP_K_E_K_GENERALIZED_H -#define MOE_GATING_TOP_K_E_K_GENERALIZED_H -#include "kernel_operator.h" -#include "common.h" -#include "kernel_utils.h" -namespace MoeGatingTopK { -using namespace AscendC; - -template -class MoeGatingTopKGenerlized { -public: - __aicore__ inline MoeGatingTopKGenerlized(){}; - __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); - __aicore__ inline void Process(); - -private: - __aicore__ inline void CopyInBiasAndInitExpertId(); - __aicore__ inline void CopyInX(int64_t progress); - __aicore__ inline void ComputeX(); - __aicore__ inline void CopuOutXNorm(int64_t row); - __aicore__ inline void SortInGroup(); - __aicore__ inline void SelectTopKGroupIndex(); - __aicore__ inline void SelectTopKExpertIdx(); - __aicore__ inline void SelectTopKExpertScore(); - __aicore__ inline void CumputeActualTopKExpertId(); - __aicore__ inline void CopyOut(int64_t row); - -private: - TPipe *pipe_; - TQue xInQueue_; - TQue yOutQueue_; - TQue expertIdxOutQueue_; - TQue outOutQueue_; - - TBuf biasBuf_; // 存放输入bias - TBuf expertIdBuf_; // 专家编号 - TBuf xNormWithBiasBuf_; // 存放加了bias之后的值 - TBuf xNormBuf_; // 存放计算sigmoid或softmax的值 - TBuf sortedInGroupBuf_; // 存放组内排序后的结果 - TBuf topKExpertIdBuf_; - TBuf sortedGroupIndexBuf_; - TBuf calcTmpBuf_; - - GlobalTensor xGm_; - GlobalTensor biasGm_; - GlobalTensor yGm_; - GlobalTensor expertIdxGm_; - GlobalTensor outGm_; - - int64_t blockIdx_ = 0; - int64_t perCoreRowCount_ = 0; - int64_t curCoreRowCount_ = 0; - int64_t expertCount_ = 0; - bool addBias_ = false; - int64_t k_ = 0; - int64_t kGroup_ = 0; - int64_t groupCount_ = 0; - int64_t groupCountAlign_ = 0; - int64_t perGroupExpertCount_ = 0; - int64_t perGroupExpertCountAlign_ = 0; - int64_t groupSelectMode_ = 0; - int64_t renorm_ = 0; - int64_t normType_ = 0; - int64_t outFlag_ = 0; - - int64_t expertCountAlign_ = 0; - int64_t kAlign_ = 0; - bool isAlign_ = false; - - const MoeGatingTopKTilingData *tilingData_; -}; - -template -__aicore__ inline void MoeGatingTopKGenerlized::CopyInBiasAndInitExpertId() -{ - LocalTensor biasTensor = biasBuf_.Get(); - LocalTensor expertIdTensor = expertIdBuf_.Get(); - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = groupCount_; - dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T); - dataCopyParams.srcStride = 0; - dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES; - - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if (addBias_) { - if constexpr (IsSameType::value) { - DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - } else { - DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, - expertCountAlign_); - PipeBarrier(); - } - - if (!isAlign_) { - int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM; - int duplicateIndex = perGroupExpertCount_ - duplicateNum; - if (duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - Duplicate(biasTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1, - perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES); - } - } - } - ArithProgression(expertIdTensor, static_cast(0), static_cast(1), expertCountAlign_); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::CopyInX(int64_t row) -{ - LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = groupCount_; - dataCopyParams.blockLen = perGroupExpertCount_ * sizeof(T); - dataCopyParams.srcStride = 0; - dataCopyParams.dstStride = (perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(T) / BLOCK_BYTES; - - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if constexpr (IsSameType::value) { - DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); - } else { - DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], dataCopyParams, - dataCopyPadParams); - } - xInQueue_.EnQue(xInLocalTensor); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::ComputeX() -{ - LocalTensor xNormTensor = xNormBuf_.Get(); - LocalTensor xInLocalTensor = xInQueue_.DeQue(); - LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); - LocalTensor biasTensor = biasBuf_.Get(); - - if constexpr (!IsSameType::value) { - Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, - expertCountAlign_); - PipeBarrier(); - } - - int64_t duplicateNum = perGroupExpertCount_ % ONE_REPEAT_SORT_NUM; - int duplicateIndex = perGroupExpertCount_ - duplicateNum; - if (!isAlign_ && duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - Duplicate(xInLocalTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, groupCount_, 1, - (perGroupExpertCountAlign_ * sizeof(float)) / BLOCK_BYTES); - PipeBarrier(); - } - if (normType_ == 1) { // sigmoid - LocalTensor calcNormTmpTensor = calcTmpBuf_.Get(); - Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCountAlign_); - PipeBarrier(); - } - else if (normType_ == 0) { // softmax - LocalTensor reduceValueTensor = calcTmpBuf_.Get(); - LocalTensor calcTmp = calcTmpBuf_.Get()[BLOCK_BYTES]; - ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCountAlign_); - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float maxValue = reduceValueTensor.GetValue(0); - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Adds(xNormTensor, xInLocalTensor, -maxValue, expertCountAlign_); - PipeBarrier(); - Exp(xNormTensor, xNormTensor, expertCountAlign_); - PipeBarrier(); - ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCountAlign_); - eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float sumValue = reduceValueTensor.GetValue(0); - eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCountAlign_); - PipeBarrier(); - } - if (addBias_) { - Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCountAlign_); - } else { - DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_); - } - - if (!isAlign_ && duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - PipeBarrier(); - Duplicate(xNormWithBiasTensor.ReinterpretCast()[duplicateIndex], - FLOAT32_NEG_INF, // MIN_FP32, - mask, groupCount_, 1, perGroupExpertCountAlign_ * sizeof(float) / BLOCK_BYTES); - } - xInQueue_.FreeTensor(xInLocalTensor); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::CopuOutXNorm(int64_t row) -{ - LocalTensor outOutTensor = outOutQueue_.AllocTensor(); - LocalTensor xNormTensor = xNormBuf_.Get(); - DataCopy(outOutTensor, xNormTensor, expertCountAlign_); - outOutQueue_.EnQue(outOutTensor); - outOutTensor = outOutQueue_.DeQue(); - DataCopyExtParams dataCopyParams{ - static_cast(groupCount_), static_cast(perGroupExpertCount_ * sizeof(float)), - static_cast((perGroupExpertCountAlign_ - perGroupExpertCount_) * sizeof(float) / BLOCK_BYTES), 0, 0}; - DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams); - outOutQueue_.FreeTensor(outOutTensor); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::SortInGroup() -{ - LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); - LocalTensor expertIdTensor = expertIdBuf_.Get(); - LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); - LocalTensor tmpLocal = calcTmpBuf_.Get(); - if (perGroupExpertCountAlign_ == ONE_REPEAT_SORT_NUM) { - PipeBarrier(); - Sort32(sortedInGroupTensor, xNormWithBiasTensor, expertIdTensor, groupCount_); - } else { - for (int64_t group = 0; group < groupCount_; group++) { - PipeBarrier(); - Sort(sortedInGroupTensor[group * perGroupExpertCountAlign_ * CONSTANT_TWO], - xNormWithBiasTensor[group * perGroupExpertCountAlign_], - expertIdTensor[group * perGroupExpertCountAlign_], tmpLocal, - perGroupExpertCountAlign_ / ONE_REPEAT_SORT_NUM); - } - } -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::SelectTopKGroupIndex() -{ - LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); - LocalTensor valueSelectedFromGroupTensor = calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, 0); - LocalTensor maskTensor = - calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 2 * sizeof(float)); - LocalTensor topValueInGroupTensor = - calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 3 * sizeof(float)); - LocalTensor groupIndex = - calcTmpBuf_.GetWithOffset(groupCountAlign_, groupCountAlign_ * 4 * sizeof(float)); - LocalTensor sortedTopValue = - calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, groupCountAlign_ * 5 * sizeof(float)); - LocalTensor sortTmp = - calcTmpBuf_.GetWithOffset(groupCountAlign_ * 2, groupCountAlign_ * 7 * sizeof(float)); - - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - PipeBarrier(); - if (groupSelectMode_ == 1) { // top2 sum - // 提取每组组前两个元素 - maskTensor.SetValue(0, static_cast(5)); // b0101 - maskTensor.SetValue(1, static_cast(0)); - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - - GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = groupCount_; - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = - Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), BLOCK_BYTES); - gatherMaskParams.src1RepeatStride = 0; - GatherMask(valueSelectedFromGroupTensor, sortedInGroupTensor, maskTensor, true, - static_cast(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt); - PipeBarrier(); - - // 计算每个组前两个数的和 - PairReduceSum(topValueInGroupTensor, valueSelectedFromGroupTensor, - Ceil(groupCount_ * sizeof(float) * 2, REPEAT_BYTES), REPEAT_BYTES / sizeof(float), 1, 1, - CONSTANT_EIGHT); // 计算每个组内最大的两个数之和 - } else { - maskTensor.SetValue(0, static_cast(1)); // b0101 - maskTensor.SetValue(1, static_cast(0)); - - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = groupCount_; - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = Ceil(perGroupExpertCountAlign_ * (sizeof(float) + sizeof(uint32_t)), 32); - gatherMaskParams.src1RepeatStride = 0; - GatherMask(topValueInGroupTensor, sortedInGroupTensor, maskTensor, true, - static_cast(ONE_REPEAT_SORT_NUM * CONSTANT_TWO), gatherMaskParams, rsvdCnt); - } - - PipeBarrier(); - // 生成组索引 - ArithProgression(groupIndex.ReinterpretCast(), static_cast(0), static_cast(1), - groupCount_); // 生成组索引 - PipeBarrier(); - - int64_t duplicateNum = groupCount_ % ONE_REPEAT_SORT_NUM; - int duplicateIndex = groupCount_ - duplicateNum; - if (duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - Duplicate(topValueInGroupTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, - REPEAT_BLOCKS); - PipeBarrier(); - } - PipeBarrier(); - - // 排序 - Sort(sortedTopValue, topValueInGroupTensor, groupIndex, sortTmp, Ceil(groupCount_, 32)); - PipeBarrier(); - - // 提取组序号 - uint8_t src1Pattern = 2; // 内置固定模式 - GatherMask(groupIndex, sortedTopValue.template ReinterpretCast(), src1Pattern, false, - static_cast(0), - {1, static_cast(Ceil(kGroup_ * sizeof(float) * CONSTANT_TWO, 256)), REPEAT_BLOCKS, 0}, rsvdCnt); - PipeBarrier(); - duplicateNum = kGroup_ % ONE_REPEAT_SORT_NUM; - if (duplicateNum > 0) { - duplicateIndex = kGroup_ - duplicateNum; - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - PipeBarrier(); - Duplicate(groupIndex.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, REPEAT_BLOCKS); - } - - // 将筛选出来的组序号降序排列 - LocalTensor sortedGroupIndex = sortedGroupIndexBuf_.Get(); - PipeBarrier(); - Sort(sortedGroupIndex, groupIndex.ReinterpretCast(), groupIndex, sortTmp, Ceil(kGroup_, 32)); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::SelectTopKExpertIdx() -{ - LocalTensor sortedInGroupTensor = sortedInGroupBuf_.Get(); - LocalTensor sortedGroupIndex = sortedGroupIndexBuf_.Get(); - LocalTensor topKExpertId = topKExpertIdBuf_.Get(); - LocalTensor mrgSort0Tensor = calcTmpBuf_.Get(); - - uint32_t offset[CONSTANT_FOUR] = {0, 0, 0, 0}; - uint16_t lenArr[CONSTANT_FOUR] = { - static_cast(perGroupExpertCount_), static_cast(perGroupExpertCount_), - static_cast(perGroupExpertCount_), static_cast(perGroupExpertCount_)}; - MrgSort4Info params{lenArr, false, 0b1111, 1}; - MrgSortSrcList srcList; - - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - - for (int32_t i = kGroup_ - 1; i >= 0; i -= CONSTANT_FOUR) { - int64_t mrgLen = Min(i + 1, CONSTANT_FOUR); - if (mrgLen > 1) { - if (mrgLen == MERGE_LIST_FOUR) { - offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; - offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; - offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2; - offset[3] = sortedGroupIndex.GetValue((i - 3) * 2) * perGroupExpertCountAlign_ * 2; - } else if (mrgLen == MERGE_LIST_THREE) { - offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; - offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; - offset[2] = sortedGroupIndex.GetValue((i - 2) * 2) * perGroupExpertCountAlign_ * 2; - offset[3] = 0; - params.elementLengths[3] = 0; - params.validBit = 0b111; - } else { - offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; - offset[1] = sortedGroupIndex.GetValue((i - 1) * 2) * perGroupExpertCountAlign_ * 2; - offset[2] = 0; - offset[3] = 0; - params.elementLengths[2] = 0; - params.elementLengths[3] = 0; - params.validBit = 0b11; - } - - srcList.src1 = sortedInGroupTensor[offset[0]]; - srcList.src2 = sortedInGroupTensor[offset[1]]; - srcList.src3 = sortedInGroupTensor[offset[2]]; - srcList.src4 = sortedInGroupTensor[offset[3]]; - - PipeBarrier(); - MrgSort(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], srcList, params); - } else { - offset[0] = sortedGroupIndex.GetValue(i * 2) * perGroupExpertCountAlign_ * 2; - PipeBarrier(); - DataCopy(mrgSort0Tensor[(kGroup_ - 1 - i) * perGroupExpertCountAlign_ * 2], sortedInGroupTensor[offset[0]], - perGroupExpertCountAlign_ * 2); - } - } - int32_t baseLoop = 4; - LocalTensor srcTensor = mrgSort0Tensor; - LocalTensor dstTensor = mrgSort0Tensor; - for (int i = 0; i < tilingData_->vmsCount; i++) { - if (i % 2 == 0) { - srcTensor = mrgSort0Tensor; - dstTensor = sortedInGroupTensor; - } else { - srcTensor = sortedInGroupTensor; - dstTensor = mrgSort0Tensor; - } - - int32_t nextBaseRow = baseLoop * MERGE_LIST_FOUR; - int32_t quotient = kGroup_ / nextBaseRow; - int32_t remainder = kGroup_ - quotient * nextBaseRow; - if (quotient > 0) { - MrgSort4Info params; - MrgSortSrcList srcList; - params.ifExhaustedSuspension = false; - params.elementLengths[0] = perGroupExpertCount_ * baseLoop; - params.elementLengths[1] = perGroupExpertCount_ * baseLoop; - params.elementLengths[2] = perGroupExpertCount_ * baseLoop; - params.elementLengths[3] = perGroupExpertCount_ * baseLoop; - params.validBit = 0b1111; - params.repeatTimes = 1; - for (int j = 0; j < quotient; j++) { - srcList.src1 = srcTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j]; - srcList.src2 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 2)]; - srcList.src3 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 4)]; - srcList.src4 = srcTensor[perGroupExpertCountAlign_ * baseLoop * (8 * j + 6)]; - PipeBarrier(); - MrgSort(dstTensor[perGroupExpertCountAlign_ * baseLoop * 8 * j], srcList, params); - } - } - - if (remainder > 0) { - int32_t baseOffset = quotient * nextBaseRow * perGroupExpertCountAlign_ * 2; - int32_t mrgLen = CeilDiv(remainder, baseLoop); - int32_t tailRow = remainder - (mrgLen - 1) * baseLoop; - if (mrgLen > 1) { - MrgSort4Info params; - MrgSortSrcList srcList; - params.repeatTimes = 1; - params.ifExhaustedSuspension = false; - params.elementLengths[0] = perGroupExpertCount_ * baseLoop; - params.elementLengths[1] = perGroupExpertCount_ * baseLoop; - params.elementLengths[2] = perGroupExpertCount_ * baseLoop; - params.elementLengths[3] = perGroupExpertCount_ * baseLoop; - srcList.src1 = srcTensor[baseOffset]; - srcList.src2 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2]; - if (mrgLen == MERGE_LIST_FOUR) { - srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2]; - srcList.src4 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 3]; - params.elementLengths[3] = perGroupExpertCount_ * tailRow; - params.validBit = 0b1111; - } else if (mrgLen == MERGE_LIST_THREE) { - srcList.src3 = srcTensor[baseOffset + perGroupExpertCountAlign_ * baseLoop * 2 * 2]; - params.elementLengths[2] = perGroupExpertCount_ * tailRow; - params.elementLengths[3] = 0; - params.validBit = 0b111; - } else { - params.elementLengths[1] = perGroupExpertCount_ * tailRow; - params.elementLengths[2] = 0; - params.elementLengths[3] = 0; - params.validBit = 0b11; - } - PipeBarrier(); - MrgSort(dstTensor[baseOffset], srcList, params); - } else { - PipeBarrier(); - DataCopy(dstTensor[baseOffset], srcTensor[baseOffset], tailRow * perGroupExpertCountAlign_ * 2); - } - } - baseLoop = nextBaseRow; - } - - GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * 2, REPEAT_BYTES); - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS; - gatherMaskParams.src1RepeatStride = 0; - - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - uint8_t src1Pattern = 2; // 内置固定模式 - PipeBarrier(); - GatherMask(topKExpertId, dstTensor.template ReinterpretCast(), src1Pattern, false, - static_cast(0), gatherMaskParams, rsvdCnt); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::SelectTopKExpertScore() -{ - LocalTensor xNormTensor = xNormBuf_.Get(); - LocalTensor yOutTensor = yOutQueue_.AllocTensor(); - LocalTensor topKExpertId = topKExpertIdBuf_.Get(); - LocalTensor topKExpertIdWithByte = calcTmpBuf_.Get(); - PipeBarrier(); - Muls(topKExpertIdWithByte, topKExpertId, static_cast(sizeof(float)), k_); - PipeBarrier(); - Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast(), static_cast(0), - k_); - bool needRenorm = (normType_ == 1 ) || // 情况1:sigmoid + renorm - (normType_ == 0 && renorm_ == 1); // 情况3:softmax + renorm - if (needRenorm) { - LocalTensor maxValueTensor = calcTmpBuf_.Get(); - LocalTensor tmpTensor = calcTmpBuf_.Get()[32]; - PipeBarrier(); - ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_); - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps; - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Duplicate(tmpTensor, sumValue, k_); - PipeBarrier(); - Div(yOutTensor, yOutTensor, tmpTensor, k_); - } - PipeBarrier(); - Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_); - - if constexpr (!IsSameType::value) { - PipeBarrier(); - Cast(yOutTensor.ReinterpretCast(), yOutTensor, RoundMode::CAST_RINT, k_); - } - - yOutQueue_.EnQue(yOutTensor); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::CumputeActualTopKExpertId() -{ - LocalTensor expertIdxOut = expertIdxOutQueue_.AllocTensor(); - LocalTensor topKExpertId = topKExpertIdBuf_.Get(); - LocalTensor topKExpertIdFp32 = calcTmpBuf_.Get(); - - PipeBarrier(); - Cast(topKExpertIdFp32, topKExpertId, RoundMode::CAST_ROUND, k_); - PipeBarrier(); - Muls(topKExpertIdFp32, topKExpertIdFp32, 1.0f / (float)perGroupExpertCountAlign_, k_); - PipeBarrier(); - Cast(expertIdxOut, topKExpertIdFp32, RoundMode::CAST_TRUNC, k_); - PipeBarrier(); - Muls(expertIdxOut, expertIdxOut, static_cast(perGroupExpertCountAlign_ - perGroupExpertCount_), k_); - PipeBarrier(); - Sub(expertIdxOut, topKExpertId, expertIdxOut, k_); - expertIdxOutQueue_.EnQue(expertIdxOut); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::CopyOut(int64_t row) -{ - LocalTensor yOutTensor = yOutQueue_.DeQue(); - LocalTensor expertIdxOut = expertIdxOutQueue_.DeQue(); - DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; - DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); - dataCopyParams.blockLen = k_ * sizeof(int32_t); - DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams); - yOutQueue_.FreeTensor(yOutTensor); - expertIdxOutQueue_.FreeTensor(expertIdxOut); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, - GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) -{ - tilingData_ = tilingData; - pipe_ = tPipe; - blockIdx_ = GetBlockIdx(); - perCoreRowCount_ = tilingData_->perCoreRowCount; - if (blockIdx_ == GetBlockNum() - 1) { - curCoreRowCount_ = tilingData_->lastCoreRowCount; - } else { - curCoreRowCount_ = tilingData_->perCoreRowCount; - } - expertCount_ = tilingData_->expertCount; - addBias_ = tilingData_->addBias == 1; - k_ = tilingData_->k; - kGroup_ = tilingData_->kGroup; - groupCount_ = tilingData_->groupCount; - groupCountAlign_ = Ceil(groupCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; - perGroupExpertCount_ = tilingData_->perGroupExpertCount; - perGroupExpertCountAlign_ = tilingData_->perGroupExpertCountAlign; - renorm_ = tilingData_->renorm; - normType_ = tilingData_->normType; - groupSelectMode_ = tilingData_->groupSelectMode; - - expertCountAlign_ = Align(perGroupExpertCountAlign_ * groupCount_, sizeof(float)); - kAlign_ = Align(k_, sizeof(float)); - - isAlign_ = perGroupExpertCount_ == perGroupExpertCountAlign_; - - // init input gm buf - xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); - - // init output gm buf - yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); - expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); - outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - - // init que - pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(yOutQueue_, 1, kAlign_ * sizeof(float)); - pipe_->InitBuffer(expertIdxOutQueue_, 1, kAlign_ * sizeof(int32_t)); - pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float)); - - pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t)); - - pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float)); - - pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float)); - pipe_->InitBuffer(sortedInGroupBuf_, expertCountAlign_ * (sizeof(float) + sizeof(uint32_t))); - - pipe_->InitBuffer(sortedGroupIndexBuf_, groupCountAlign_ * sizeof(float) * CONSTANT_TWO); - pipe_->InitBuffer(topKExpertIdBuf_, kAlign_ * sizeof(int32_t)); - pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * 10); -} - -template -__aicore__ inline void MoeGatingTopKGenerlized::Process() -{ - CopyInBiasAndInitExpertId(); - for (int64_t row = 0; row < curCoreRowCount_; row++) { - CopyInX(row); - ComputeX(); - if (tilingData_->outFlag) { - CopuOutXNorm(row); - } - SortInGroup(); - SelectTopKGroupIndex(); - SelectTopKExpertIdx(); - SelectTopKExpertScore(); - CumputeActualTopKExpertId(); - CopyOut(row); - } -} -} // namespace MoeGatingTopK -#endif // MOE_GATING_TOP_K_E_K_GENERALIZED_H diff --git a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h b/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h deleted file mode 100644 index ffdc8331..00000000 --- a/csrc/moe_gating_top_k/op_kernel/moe_gating_top_k_without_group.h +++ /dev/null @@ -1,338 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file moe_gating_top_k_without_group.h - * \brief - */ -#ifndef MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H -#define MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H -#include "kernel_operator.h" -#include "common.h" -#include "kernel_utils.h" -namespace MoeGatingTopK { -using namespace AscendC; - -template -class MoeGatingTopKWithoutGroup { -public: - __aicore__ inline MoeGatingTopKWithoutGroup(){}; - __aicore__ inline void Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe); - __aicore__ inline void Process(); - -private: - __aicore__ inline void CopyInBiasAndInitExpertId(); - __aicore__ inline void CopyInX(int64_t progress); - __aicore__ inline void ComputeX(); - __aicore__ inline void CopuOutXNorm(int64_t row); - __aicore__ inline void SelectTopKExpertIdx(); - __aicore__ inline void SelectTopKExpertScore(); - __aicore__ inline void CopyOut(int64_t row); - -private: - TPipe *pipe_; - TQue xInQueue_; - TQue yOutQueue_; - TQue expertIdxOutQueue_; - TQue outOutQueue_; - - TBuf biasBuf_; // 存放输入bias - TBuf expertIdBuf_; // 专家编号 - TBuf xNormWithBiasBuf_; // 存放加了bias之后的值 - TBuf xNormBuf_; // 存放计算sigmoid或softmax的值 - TBuf topKExpertIdBuf_; - TBuf calcTmpBuf_; - - GlobalTensor xGm_; - GlobalTensor biasGm_; - GlobalTensor yGm_; - GlobalTensor expertIdxGm_; - GlobalTensor outGm_; - - int64_t blockIdx_ = 0; - int64_t perCoreRowCount_ = 0; - int64_t curCoreRowCount_ = 0; - int64_t expertCount_ = 0; - bool addBias_ = false; - bool outFlag_ = false; - int64_t k_ = 0; - int64_t renorm_ = 0; - int64_t normType_ = 0; - int64_t expertCountAlign_ = 0; - const MoeGatingTopKTilingData *tilingData_; -}; - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::CopyInBiasAndInitExpertId() -{ - LocalTensor biasTensor = biasBuf_.Get(); - LocalTensor expertIdTensor = expertIdBuf_.Get(); - DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if (addBias_) { - if constexpr (IsSameType::value) { - DataCopyPad(biasTensor, biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - } else { - DataCopyPad(biasTensor[expertCountAlign_].ReinterpretCast(), biasGm_, dataCopyParams, dataCopyPadParams); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Cast(biasTensor, biasTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, - expertCountAlign_); - PipeBarrier(); - } - } - ArithProgression(expertIdTensor, static_cast(0), static_cast(1), expertCount_); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::CopyInX(int64_t row) -{ - LocalTensor xInLocalTensor = xInQueue_.AllocTensor(); - DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(T)), 0, 0, 0}; - DataCopyPadExtParams dataCopyPadParams{false, 0, 0, static_cast(0)}; - if constexpr (IsSameType::value) { - DataCopyPad(xInLocalTensor, xGm_[row * expertCount_], dataCopyParams, dataCopyPadParams); - } else { - DataCopyPad(xInLocalTensor[expertCountAlign_].ReinterpretCast(), xGm_[row * expertCount_], dataCopyParams, - dataCopyPadParams); - } - xInQueue_.EnQue(xInLocalTensor); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::ComputeX() -{ - LocalTensor xNormTensor = xNormBuf_.Get(); - LocalTensor xInLocalTensor = xInQueue_.DeQue(); - LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); - LocalTensor biasTensor = biasBuf_.Get(); - - if constexpr (!IsSameType::value) { - Cast(xInLocalTensor, xInLocalTensor[expertCountAlign_].ReinterpretCast(), RoundMode::CAST_NONE, - expertCount_); - PipeBarrier(); - } - - if (normType_ == 1) { // sigmoid - LocalTensor calcNormTmpTensor = calcTmpBuf_.Get(); - Sigmoid(xNormTensor, xInLocalTensor, calcNormTmpTensor, expertCount_); - PipeBarrier(); - } else if (normType_ == 0) { // sigmoid - LocalTensor reduceValueTensor = calcTmpBuf_.Get(); - LocalTensor calcTmp = calcTmpBuf_.Get()[8]; - ReduceMax(reduceValueTensor, xInLocalTensor, calcTmp, expertCount_); - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float maxValue = reduceValueTensor.GetValue(0); - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Adds(xNormTensor, xInLocalTensor, -maxValue, expertCount_); - PipeBarrier(); - Exp(xNormTensor, xNormTensor, expertCount_); - PipeBarrier(); - ReduceSum(reduceValueTensor, xNormTensor, calcTmp, expertCount_); - eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float sumValue = reduceValueTensor.GetValue(0); - eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Muls(xNormTensor, xNormTensor, 1.0f / sumValue, expertCount_); - PipeBarrier(); - } - if (addBias_) { - Add(xNormWithBiasTensor, xNormTensor, biasTensor, expertCount_); - } else { - DataCopy(xNormWithBiasTensor, xNormTensor, expertCountAlign_); - } - - int64_t duplicateNum = expertCount_ % ONE_REPEAT_SORT_NUM; - int duplicateIndex = expertCount_ - duplicateNum; - if (duplicateNum > 0) { - uint64_t mask0 = UINT64_MAX; - mask0 = mask0 << duplicateNum; - mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); - uint64_t mask[2] = {mask0, 0}; - Duplicate(xNormWithBiasTensor.ReinterpretCast()[duplicateIndex], FLOAT32_NEG_INF, mask, 1, 1, 1); - PipeBarrier(); - } - xInQueue_.FreeTensor(xInLocalTensor); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::CopuOutXNorm(int64_t row) -{ - LocalTensor outOutTensor = outOutQueue_.AllocTensor(); - LocalTensor xNormTensor = xNormBuf_.Get(); - DataCopy(outOutTensor, xNormTensor, expertCountAlign_); - outOutQueue_.EnQue(outOutTensor); - outOutTensor = outOutQueue_.DeQue(); - DataCopyExtParams dataCopyParams{1, static_cast(expertCount_ * sizeof(float)), 0, 0, 0}; - DataCopyPad(outGm_[row * expertCount_], outOutTensor, dataCopyParams); - outOutQueue_.FreeTensor(outOutTensor); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::SelectTopKExpertIdx() -{ - LocalTensor expertIdxOut = expertIdxOutQueue_.AllocTensor(); - LocalTensor xNormWithBiasTensor = xNormWithBiasBuf_.Get(); - LocalTensor expertIdTensor = expertIdBuf_.Get(); - LocalTensor topKExpertId = topKExpertIdBuf_.Get(); - LocalTensor sortedScore = calcTmpBuf_.Get(); - LocalTensor sortTmp = calcTmpBuf_.Get()[expertCountAlign_ * CONSTANT_TWO]; - PipeBarrier(); - Sort(sortedScore, xNormWithBiasTensor, expertIdTensor, sortTmp, - expertCountAlign_ / ONE_REPEAT_SORT_NUM); - - GatherMaskParams gatherMaskParams; - gatherMaskParams.repeatTimes = Ceil(k_ * sizeof(float) * CONSTANT_TWO, REPEAT_BYTES); - gatherMaskParams.src0BlockStride = 1; - gatherMaskParams.src0RepeatStride = REPEAT_BLOCKS; - gatherMaskParams.src1RepeatStride = 0; - - uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 - uint8_t src1Pattern = 2; // 内置固定模式 - PipeBarrier(); - GatherMask(topKExpertId, sortedScore.template ReinterpretCast(), src1Pattern, false, - static_cast(0), gatherMaskParams, rsvdCnt); - - DataCopy(expertIdxOut, topKExpertId, expertCountAlign_); - expertIdxOutQueue_.EnQue(expertIdxOut); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::SelectTopKExpertScore() -{ - LocalTensor xNormTensor = xNormBuf_.Get(); - LocalTensor yOutTensor = yOutQueue_.AllocTensor(); - LocalTensor topKExpertId = topKExpertIdBuf_.Get(); - LocalTensor topKExpertIdWithByte = calcTmpBuf_.Get(); - PipeBarrier(); - Muls(topKExpertIdWithByte, topKExpertId, static_cast(sizeof(float)), k_); - PipeBarrier(); - Gather(yOutTensor, xNormTensor, topKExpertIdWithByte.template ReinterpretCast(), static_cast(0), - k_); - - bool needRenorm = (normType_ == 1 ) || // 情况1:sigmoid + renorm - (normType_ == 0 && renorm_ == 1); // 情况3:softmax + renorm - if (needRenorm == 1) { - LocalTensor maxValueTensor = calcTmpBuf_.Get(); - LocalTensor tmpTensor = calcTmpBuf_.Get()[BLOCK_BYTES]; - PipeBarrier(); - ReduceSum(maxValueTensor, yOutTensor, tmpTensor, k_); - event_t eventIdVToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); - float sumValue = maxValueTensor.GetValue(0) + tilingData_->eps; - event_t eventIdSToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); - SetFlag(eventIdSToV); - WaitFlag(eventIdSToV); - Duplicate(tmpTensor, sumValue, k_); - PipeBarrier(); - Div(yOutTensor, yOutTensor, tmpTensor, k_); - } - PipeBarrier(); - Muls(yOutTensor, yOutTensor, tilingData_->routedScalingFactor, k_); - - if constexpr (!IsSameType::value) { - PipeBarrier(); - Cast(yOutTensor.ReinterpretCast(), yOutTensor, RoundMode::CAST_RINT, k_); - } - - yOutQueue_.EnQue(yOutTensor); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::CopyOut(int64_t row) -{ - LocalTensor yOutTensor = yOutQueue_.DeQue(); - LocalTensor expertIdxOut = expertIdxOutQueue_.DeQue(); - DataCopyExtParams dataCopyParams{1, static_cast(k_ * sizeof(T)), 0, 0, 0}; - DataCopyPad(yGm_[row * k_], yOutTensor, dataCopyParams); - dataCopyParams.blockLen = k_ * sizeof(int32_t); - DataCopyPad(expertIdxGm_[row * k_], expertIdxOut, dataCopyParams); - yOutQueue_.FreeTensor(yOutTensor); - expertIdxOutQueue_.FreeTensor(expertIdxOut); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::Init(GM_ADDR x, GM_ADDR bias, GM_ADDR y, GM_ADDR expertIdx, - GM_ADDR out, GM_ADDR workspace, - const MoeGatingTopKTilingData *tilingData, TPipe *tPipe) -{ - tilingData_ = tilingData; - pipe_ = tPipe; - blockIdx_ = GetBlockIdx(); - perCoreRowCount_ = tilingData_->perCoreRowCount; - if (blockIdx_ == GetBlockNum() - 1) { - curCoreRowCount_ = tilingData_->lastCoreRowCount; - } else { - curCoreRowCount_ = tilingData_->perCoreRowCount; - } - expertCount_ = tilingData_->expertCount; - addBias_ = tilingData_->addBias == 1; - outFlag_ = tilingData_->outFlag == 1; - k_ = tilingData_->k; - renorm_ = tilingData_->renorm; - normType_ = tilingData_->normType; - expertCountAlign_ = Ceil(expertCount_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; - - // init input gm buf - xGm_.SetGlobalBuffer((__gm__ T *)x + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - biasGm_.SetGlobalBuffer((__gm__ T *)bias, expertCount_); - - // init output gm buf - yGm_.SetGlobalBuffer((__gm__ T *)y + perCoreRowCount_ * k_ * blockIdx_, k_); - expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx + perCoreRowCount_ * k_ * blockIdx_, k_); - outGm_.SetGlobalBuffer((__gm__ float *)out + perCoreRowCount_ * expertCount_ * blockIdx_, expertCount_); - - // init que - pipe_->InitBuffer(xInQueue_, 1, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(yOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(float)); - pipe_->InitBuffer(expertIdxOutQueue_, 1, Align(k_, sizeof(float)) * sizeof(int32_t)); - pipe_->InitBuffer(outOutQueue_, 1, expertCountAlign_ * sizeof(float)); - - // init calc buf - pipe_->InitBuffer(biasBuf_, expertCountAlign_ * sizeof(float) * (sizeof(float) / sizeof(T))); - pipe_->InitBuffer(expertIdBuf_, expertCountAlign_ * sizeof(int32_t)); - pipe_->InitBuffer(xNormBuf_, expertCountAlign_ * sizeof(float)); - pipe_->InitBuffer(xNormWithBiasBuf_, expertCountAlign_ * sizeof(float)); - pipe_->InitBuffer(topKExpertIdBuf_, Align(k_, sizeof(float)) * sizeof(int32_t)); - - // init tmp buf - pipe_->InitBuffer(calcTmpBuf_, expertCountAlign_ * sizeof(float) * CONSTANT_EIGHT); -} - -template -__aicore__ inline void MoeGatingTopKWithoutGroup::Process() -{ - CopyInBiasAndInitExpertId(); - for (int64_t row = 0; row < curCoreRowCount_; row++) { - CopyInX(row); - ComputeX(); - if (outFlag_) { - CopuOutXNorm(row); - } - SelectTopKExpertIdx(); - SelectTopKExpertScore(); - CopyOut(row); - } -} -} // namespace MoeGatingTopK -#endif // MOE_GATING_TOP_K_E_K_WITHOUT_GROUP_H \ No newline at end of file diff --git a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h deleted file mode 100644 index f7e7fce1..00000000 --- a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file data_copy_transpose_tiling.h - * \brief - */ - -#pragma once - -#include -#include -#include "data_copy_transpose_tiling_def.h" - -namespace optiling { - -inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize, - optiling::CopyTransposeTiling &tiling) -{ - constexpr int64_t B_INDEX = 0; - constexpr int64_t N_INDEX = 1; - constexpr int64_t S_INDEX = 2; - constexpr int64_t H_INDEX = 3; - std::vector dstShapeInfo = dstShape.GetDims(); - std::vector srcShapeInfo = srcShape.GetDims(); - - tiling.set_dstShapeB(dstShapeInfo[B_INDEX]); - tiling.set_dstShapeN(dstShapeInfo[N_INDEX]); - tiling.set_dstShapeS(dstShapeInfo[S_INDEX]); - tiling.set_dstShapeH(dstShapeInfo[H_INDEX]); - tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN()); - - tiling.set_srcShapeB(srcShapeInfo[B_INDEX]); - tiling.set_srcShapeN(srcShapeInfo[N_INDEX]); - tiling.set_srcShapeS(srcShapeInfo[S_INDEX]); - tiling.set_srcShapeHN(srcShapeInfo[H_INDEX]); - tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize); - tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH()); - tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS()); - tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN()); - tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH()); -} - -} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h b/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h deleted file mode 100644 index 18552c36..00000000 --- a/csrc/moe_gating_top_k/tiling_base/data_copy_transpose_tiling_def.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file data_copy_transpose_tiling_def.h - * \brief - */ - -#pragma once - -#include -#include - -namespace optiling { - -BEGIN_TILING_DATA_DEF(CopyTransposeTiling) -TILING_DATA_FIELD_DEF(uint32_t, dstShapeB); -TILING_DATA_FIELD_DEF(uint32_t, dstShapeN); -TILING_DATA_FIELD_DEF(uint32_t, dstShapeS); -TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN); -TILING_DATA_FIELD_DEF(uint32_t, dstShapeH); -TILING_DATA_FIELD_DEF(uint32_t, srcShapeB); -TILING_DATA_FIELD_DEF(uint32_t, srcShapeN); -TILING_DATA_FIELD_DEF(uint32_t, srcShapeS); -TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN); -TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen); -TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue); -TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue); -TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue); -TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling); -TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue); -TILING_DATA_FIELD_DEF(uint32_t, paramsAlign); -END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling) - -} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/error_log.h b/csrc/moe_gating_top_k/tiling_base/error_log.h deleted file mode 100644 index 8b64eced..00000000 --- a/csrc/moe_gating_top_k/tiling_base/error_log.h +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ -#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ - -#include -#include "toolchain/slog.h" - -#define OP_LOGI(opname, ...) -#define OP_LOGW(opname, ...) \ - do { \ - printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ - do { \ - printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGE(opname, ...) \ - do { \ - printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#define OP_LOGD(opname, ...) - -namespace optiling { - -#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ - do { \ - OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ - } while (0) - -// 修改 OP_TILING_CHECK 宏,确保正确处理表达式 -#define OP_CHECK_IF(cond, log_func, expr) \ - do { \ - if (cond) { \ - log_func; \ - expr; \ - } \ - } while (0) - - - -#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ - do { \ - if ((ptr) == nullptr) { \ - OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -} // namespace optiling - -#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_base.h b/csrc/moe_gating_top_k/tiling_base/tiling_base.h deleted file mode 100644 index 3dbaab7a..00000000 --- a/csrc/moe_gating_top_k/tiling_base/tiling_base.h +++ /dev/null @@ -1,256 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file tiling_base.h - * \brief - */ - -#pragma once - -#include -#include -#include -#include "tiling/platform/platform_ascendc.h" -#include "error_log.h" - -#ifdef ASCENDC_OP_TEST -#define ASCENDC_EXTERN_C extern "C" -#else -#define ASCENDC_EXTERN_C -#endif - -namespace Ops { -namespace Transformer { -namespace OpTiling { - -struct AiCoreParams { - uint64_t ubSize = 0; - uint64_t blockDim = 0; - uint64_t aicNum = 0; - uint64_t l1Size = 0; - uint64_t l0aSize = 0; - uint64_t l0bSize = 0; - uint64_t l0cSize = 0; -}; - -struct CompileInfoCommon { - uint32_t aivNum; - uint32_t aicNum; - uint64_t ubSize; - uint64_t l1Size; - uint64_t l0aSize; - uint64_t l0bSize; - uint64_t l0cSize; - uint64_t l2CacheSize; - int64_t coreNum; - int32_t socVersion; - uint32_t rsvd; -}; - -struct FlashAttentionScoreGradCompileInfo { - uint32_t aivNum; - uint32_t aicNum; - uint64_t ubSize; - uint64_t l1Size; - uint64_t l0aSize; - uint64_t l0bSize; - uint64_t l0cSize; - uint64_t l2CacheSize; - int64_t coreNum; - platform_ascendc::SocVersion socVersion; -}; - -struct FACompileInfoCommon { - uint32_t aivNum; - uint32_t aicNum; - uint64_t ubSize; - uint64_t l1Size; - uint64_t l0aSize; - uint64_t l0bSize; - uint64_t l0cSize; - uint64_t l2CacheSize; - int64_t coreNum; - int32_t socVersion; - uint32_t rsvd; -}; - -class TilingBaseClass { -public: - explicit TilingBaseClass(gert::TilingContext* context) : context_(context) - {} - - virtual ~TilingBaseClass() = default; - - // Tiling执行框架 - // 1、GRAPH_SUCCESS: 成功,并且不需要继续执行后续Tiling类的实现 - // 2、GRAPH_FAILED: 失败,中止整个Tiling流程 - // 3、GRAPH_PARAM_INVALID: 本类不支持,需要继续往下执行其他Tiling类的实现 - ge::graphStatus DoTiling() - { - auto ret = GetShapeAttrsInfo(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - ret = GetPlatformInfo(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - if (!IsCapable()) { - return ge::GRAPH_PARAM_INVALID; - } - ret = DoOpTiling(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - ret = DoLibApiTiling(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - ret = GetWorkspaceSize(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - ret = PostTiling(); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - context_->SetTilingKey(GetTilingKey()); - DumpTilingInfo(); - return ge::GRAPH_SUCCESS; - } - - // 更新 context - virtual void Reset(gert::TilingContext* context) - { - context_ = context; - } - -protected: - virtual bool IsCapable() = 0; - // 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小 - virtual ge::graphStatus GetPlatformInfo() = 0; - // 2、获取INPUT/OUTPUT/ATTR信息 - virtual ge::graphStatus GetShapeAttrsInfo() = 0; - // 3、计算数据切分TilingData - virtual ge::graphStatus DoOpTiling() = 0; - // 4、计算高阶API的TilingData - virtual ge::graphStatus DoLibApiTiling() = 0; - // 5、计算TilingKey - [[nodiscard]] virtual uint64_t GetTilingKey() const = 0; - // 6、计算Workspace 大小 - virtual ge::graphStatus GetWorkspaceSize() = 0; - // 7、保存Tiling数据 - virtual ge::graphStatus PostTiling() = 0; - // 8、Dump Tiling数据 - virtual void DumpTilingInfo() - { - int32_t enable = CheckLogLevel(static_cast(OP), DLOG_DEBUG); - if (enable != 1) { - return; - } - auto buf = (uint32_t*)context_->GetRawTilingData()->GetData(); - auto bufLen = context_->GetRawTilingData()->GetDataSize(); - std::ostringstream oss; - oss << "Start to dump tiling info. tilingkey:" << context_->GetTilingKey() << ", tiling data size:" << bufLen - << ", content:"; - for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) { - oss << *(buf + i) << ","; - if (oss.str().length() > 640) { // Split according to 640 to avoid truncation - OP_LOGD(context_, "%s", oss.str().c_str()); - oss.str(""); - } - } - OP_LOGD(context_, "%s", oss.str().c_str()); - } - - static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) - { - uint32_t ration; - if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) { - return sliceNum; - } - ration = aivCoreNum / aicCoreNum; - return (sliceNum + (ration - 1)) / ration; - } - - template - [[nodiscard]] std::string GetShapeDebugStr(const T& shape) const - { - std::ostringstream oss; - oss << "["; - if (shape.GetDimNum() > 0) { - for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { - oss << shape.GetDim(i) << ", "; - } - oss << shape.GetDim(shape.GetDimNum() - 1); - } - oss << "]"; - return oss.str(); - } - - [[nodiscard]] std::string GetTensorDebugStr( - const gert::StorageShape* shape, const gert::CompileTimeTensorDesc* tensor) - { - if (shape == nullptr || tensor == nullptr) { - return "nil "; - } - std::ostringstream oss; - oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),"; - oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),"; - oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),"; - oss << "(format: " - << ge::TypeUtils::FormatToSerialString( - static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) - << "),"; - oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") "; - return oss.str(); - } - - [[nodiscard]] std::string GetTilingContextDebugStr() - { - std::ostringstream oss; - for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) { - oss << "input" << i << ": "; - oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i)); - } - - for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) { - oss << "output" << i << ": "; - oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i)); - } - return oss.str(); - } - - [[nodiscard]] std::string GetTilingDataDebugStr() const - { - auto rawTilingData = context_->GetRawTilingData(); - auto rawTilingDataSize = rawTilingData->GetDataSize(); - auto data = reinterpret_cast(rawTilingData->GetData()); - size_t len = rawTilingDataSize / sizeof(int32_t); - std::ostringstream oss; - for (size_t i = 0; i < len; i++) { - oss << data[i] << ", "; - } - return oss.str(); - } - -protected: - gert::TilingContext* context_ = nullptr; - std::unique_ptr ascendcPlatform_{nullptr}; - uint32_t blockDim_{0}; - uint64_t workspaceSize_{0}; - uint64_t tilingKey_{0}; - AiCoreParams aicoreParams_; -}; - -} // namespace OpTiling -} // namespace Transformer -} // namespace Ops \ No newline at end of file diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_key.h b/csrc/moe_gating_top_k/tiling_base/tiling_key.h deleted file mode 100644 index cbaaa5e8..00000000 --- a/csrc/moe_gating_top_k/tiling_base/tiling_key.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file tiling_key.h - * \brief - */ - -#pragma once - -#include - -namespace Ops { -namespace Transformer { -namespace OpTiling { -constexpr uint64_t RecursiveSum() -{ - return 0; -} - -constexpr uint64_t kBase = 10; // 10进制进位基数 -template constexpr uint64_t RecursiveSum(T templateId, Args... templateIds) -{ - return static_cast(templateId) + kBase * RecursiveSum(templateIds...); -} - -// TilingKey 的生成规则: -// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1, -// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1: -// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分, -// 那么填AXIS_NONE。UB0和UB1各占一个十进制位; -// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位; -// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位 -// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位 -// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位 -// 其余特化场景,定义自己的位域和值 -// usage: get tilingKey from inputed types -// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, -// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL) - -constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19 -template constexpr uint64_t GET_TILINGKEY(Args... templateIds) -{ - return TILINGKEYOFFSET + RecursiveSum(templateIds...); -} - -// usage: get tilingKey from inputed types -// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL) - -#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \ - (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \ - SparseEnum::sparse)) - -} // namespace Optiling -} // namespace Transformer -} // namespace Ops diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h b/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h deleted file mode 100644 index a9f3dc5f..00000000 --- a/csrc/moe_gating_top_k/tiling_base/tiling_templates_registry.h +++ /dev/null @@ -1,351 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file tiling_templates_registry.h - * \brief - */ - -#pragma once - -#include -#include -#include -#include "exe_graph/runtime/tiling_context.h" -#include "tiling_base.h" -#include "error_log.h" - -namespace Ops { -namespace Transformer { -namespace OpTiling { - -template -std::unique_ptr TILING_CLASS(gert::TilingContext* context) -{ - return std::unique_ptr(new (std::nothrow) T(context)); -} - -using TilingClassCase = std::unique_ptr (*)(gert::TilingContext*); - -class TilingCases { -public: - explicit TilingCases(std::string op_type) : op_type_(std::move(op_type)) - {} - - template - void AddTiling(int32_t priority) - { - OP_CHECK_IF( - cases_.find(priority) != cases_.end(), OP_LOGE(op_type_, "There are duplicate registrations."), return); - cases_[priority] = TILING_CLASS; - OP_CHECK_IF( - cases_[priority] == nullptr, - OP_LOGE(op_type_, "Register op tiling func failed, please check the class name."), return); - } - - const std::map& GetTilingCases() - { - return cases_; - } - -private: - std::map cases_; - const std::string op_type_; -}; - -// --------------------------------Interfacce with soc version -------------------------------- -class TilingRegistryNew { -public: - TilingRegistryNew() = default; - -#ifdef ASCENDC_OP_TEST - static TilingRegistryNew& GetInstance(); -#else - static TilingRegistryNew& GetInstance() - { - static TilingRegistryNew registry_impl_; - return registry_impl_; - } -#endif - - std::shared_ptr RegisterOp(const std::string& op_type, int32_t soc_version) - { - auto soc_iter = registry_map_.find(soc_version); - if (soc_iter == registry_map_.end()) { - std::map> op_type_map; - op_type_map[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); - registry_map_[soc_version] = op_type_map; - } else { - if (soc_iter->second.find(op_type) == soc_iter->second.end()) { - soc_iter->second[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); - } - } - - OP_CHECK_IF( - registry_map_[soc_version][op_type] == nullptr, - OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); - return registry_map_[soc_version][op_type]; - } - - ge::graphStatus DoTilingImpl(gert::TilingContext* context) - { - int32_t soc_version = (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION; - const char* op_type = context->GetNodeType(); - fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); - if (platformInfoPtr == nullptr) { - auto compileInfoPtr = static_cast(context->GetCompileInfo()); - OP_CHECK_IF( - compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); - soc_version = compileInfoPtr->socVersion; - OP_LOGD(context, "soc version in compileInfo is %d", soc_version); - } else { - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); - soc_version = static_cast(ascendcPlatform.GetSocVersion()); - OP_LOGD(context, "soc version is %d", soc_version); - if (soc_version == (int32_t)platform_ascendc::SocVersion::RESERVED_VERSION) { - OP_LOGE(op_type, "Do op tiling failed, cannot find soc version."); - return ge::GRAPH_FAILED; - } - } - auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); - for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { - auto tilingTemplate = it->second(context); - if (tilingTemplate != nullptr) { - ge::graphStatus status = tilingTemplate->DoTiling(); - if (status != ge::GRAPH_PARAM_INVALID) { - OP_LOGD(context, "Do general op tiling success priority=%d", it->first); - return status; - } - OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); - } - } - OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); - return ge::GRAPH_FAILED; - } - - ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) - { - int32_t soc_version; - const char* op_type = context->GetNodeType(); - auto platformInfoPtr = context->GetPlatformInfo(); - if (platformInfoPtr == nullptr) { - auto compileInfoPtr = reinterpret_cast(context->GetCompileInfo()); - OP_CHECK_IF( - compileInfoPtr == nullptr, OP_LOGE(op_type, "compileInfoPtr is null."), return ge::GRAPH_FAILED); - soc_version = compileInfoPtr->socVersion; - OP_LOGD(context, "soc version in compileInfo is %d", soc_version); - } else { - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); - soc_version = static_cast(ascendcPlatform.GetSocVersion()); - OP_LOGD(context, "soc version is %d", soc_version); - } - - auto tilingTemplateRegistryMap = GetTilingTemplates(op_type, soc_version); - for (auto priority_id : priorities) { - auto tilingCaseIter = tilingTemplateRegistryMap.find(priority_id); - if (tilingCaseIter != tilingTemplateRegistryMap.end()) { - auto templateFunc = tilingCaseIter->second(context); - if (templateFunc != nullptr) { - ge::graphStatus status = templateFunc->DoTiling(); - if (status == ge::GRAPH_SUCCESS) { - OP_LOGD(context, "Do general op tiling success priority=%d", priority_id); - return status; - } - OP_LOGD(context, "Ignore general op tiling priority=%d", priority_id); - } - } - } - return ge::GRAPH_FAILED; - } - - const std::map& GetTilingTemplates(const std::string& op_type, int32_t soc_version) - { - auto soc_iter = registry_map_.find(soc_version); - OP_CHECK_IF( - soc_iter == registry_map_.end(), - OP_LOGE(op_type, "Get op tiling func failed, please check the soc version %d", soc_version), - return empty_tiling_case_); - auto op_iter = soc_iter->second.find(op_type); - OP_CHECK_IF( - op_iter == soc_iter->second.end(), OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), - return empty_tiling_case_); - return op_iter->second->GetTilingCases(); - } - -private: - std::map>> registry_map_; // key is socversion - const std::map empty_tiling_case_{}; -}; - -class RegisterNew { -public: - explicit RegisterNew(std::string op_type) : op_type_(std::move(op_type)) - {} - - template - RegisterNew& tiling(int32_t priority, int32_t soc_version) - { - auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); - OP_CHECK_IF( - tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); - tilingCases->AddTiling(priority); - return *this; - } - - template - RegisterNew& tiling(int32_t priority, const std::vector& soc_versions) - { - for (int32_t soc_version : soc_versions) { - auto tilingCases = TilingRegistryNew::GetInstance().RegisterOp(op_type_, soc_version); - OP_CHECK_IF( - tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), - return *this); - tilingCases->AddTiling(priority); - } - return *this; - } - -private: - const std::string op_type_; -}; - -// --------------------------------Interfacce without soc version -------------------------------- -class TilingRegistry { -public: - TilingRegistry() = default; - -#ifdef ASCENDC_OP_TEST - static TilingRegistry& GetInstance(); -#else - static TilingRegistry& GetInstance() - { - static TilingRegistry registry_impl_; - return registry_impl_; - } -#endif - - std::shared_ptr RegisterOp(const std::string& op_type) - { - if (registry_map_.find(op_type) == registry_map_.end()) { - registry_map_[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); - } - OP_CHECK_IF( - registry_map_[op_type] == nullptr, - OP_LOGE(op_type, "Register tiling func failed, please check the class name."), return nullptr); - return registry_map_[op_type]; - } - - ge::graphStatus DoTilingImpl(gert::TilingContext* context) - { - const char* op_type = context->GetNodeType(); - auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); - for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { - auto tilingTemplate = it->second(context); - if (tilingTemplate != nullptr) { - ge::graphStatus status = tilingTemplate->DoTiling(); - if (status != ge::GRAPH_PARAM_INVALID) { - OP_LOGD(context, "Do general op tiling success priority=%d", it->first); - return status; - } - OP_LOGD(context, "Ignore general op tiling priority=%d", it->first); - } - } - OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); - return ge::GRAPH_FAILED; - } - - ge::graphStatus DoTilingImpl(gert::TilingContext* context, const std::vector& priorities) - { - const char* op_type = context->GetNodeType(); - auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); - for (auto priorityId : priorities) { - auto templateFunc = tilingTemplateRegistryMap[priorityId](context); - if (templateFunc != nullptr) { - ge::graphStatus status = templateFunc->DoTiling(); - if (status == ge::GRAPH_SUCCESS) { - OP_LOGD(context, "Do general op tiling success priority=%d", priorityId); - return status; - } - if (status != ge::GRAPH_PARAM_INVALID) { - OP_LOGD(context, "Do op tiling failed"); - return status; - } - OP_LOGD(context, "Ignore general op tiling priority=%d", priorityId); - } - } - OP_LOGE(op_type, "Do op tiling failed, no valid template is found."); - return ge::GRAPH_FAILED; - } - - const std::map& GetTilingTemplates(const std::string& op_type) - { - OP_CHECK_IF( - registry_map_.find(op_type) == registry_map_.end(), - OP_LOGE(op_type, "Get op tiling func failed, please check the op name."), return empty_tiling_case_); - return registry_map_[op_type]->GetTilingCases(); - } - -private: - std::map> registry_map_; - const std::map empty_tiling_case_; -}; - -class Register { -public: - explicit Register(std::string op_type) : op_type_(std::move(op_type)) - {} - - template - Register& tiling(int32_t priority) - { - auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_); - OP_CHECK_IF( - tilingCases == nullptr, OP_LOGE(op_type_, "Register op tiling failed, please the op name."), return *this); - tilingCases->AddTiling(priority); - return *this; - } - -private: - const std::string op_type_; -}; -} // namespace OpTiling -} // namespace Transformer -} // namespace Ops - -// op_type: 算子名称, class_name: 注册的 tiling 类, soc_version:芯片版本号 -// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类 -#define REGISTER_TILING_TEMPLATE_WITH_SOCVERSION(op_type, class_name, soc_versions, priority) \ - [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ - static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ - Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_versions) - -// op_type: 算子名称, class_name: 注册的 tiling 类, -// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大 -#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \ - [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ - static Ops::Transformer::OpTiling::Register VAR_UNUSED##op_type_##class_name##priority_register = \ - Ops::Transformer::OpTiling::Register(op_type).tiling(priority) - -// op_type: 算子名称, class_name: 注册的 tiling 类, -// soc_version: soc版本,用于区分不同的soc -// priority: tiling 类的优先级, 越小表示优先级越高, 即会优先选择这个tiling类 -#define REGISTER_TILING_TEMPLATE_NEW(op_type, class_name, soc_version, priority) \ - [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ - static Ops::Transformer::OpTiling::RegisterNew VAR_UNUSED##op_type##class_name##priority_register = \ - Ops::Transformer::OpTiling::RegisterNew(#op_type).tiling(priority, soc_version) - -// op_type: 算子名称, class_name: 注册的 tiling 类, -// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大 -// 取代 REGISTER_TILING_TEMPLATE , 传入的op_type如果是字符串常量,需要去掉引号 -#define REGISTER_OPS_TILING_TEMPLATE(op_type, class_name, priority) \ - [[maybe_unused]] uint32_t op_impl_register_template_##op_type##_##class_name##priority; \ - static Ops::Transformer::OpTiling::Register \ - __attribute__((unused)) tiling_##op_type##_##class_name##_##priority##_register = \ - Ops::Transformer::OpTiling::Register(#op_type).tiling(priority) diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_type.h b/csrc/moe_gating_top_k/tiling_base/tiling_type.h deleted file mode 100644 index fd25c639..00000000 --- a/csrc/moe_gating_top_k/tiling_base/tiling_type.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file tiling_type.h - * \brief - */ - -#pragma once - -#include - -namespace optiling { - -enum class AxisEnum { - B = 0, - N2 = 1, - G = 2, - S1 = 3, - S2 = 4, - D = 5, - NONE = 9, -}; - -enum class DtypeEnum { - FLOAT16 = 0, - FLOAT32 = 1, - BFLOAT16 = 2, - FLOAT16_PRECISION = 3, -}; - -enum class PerformanceOrientedEnum { - BIG_BUFFER = 1, - BIG_DOUBLE_BUFFER = 2, -}; - -enum class MatmulConfig { - NULL_CONFIG = 0, - NORMAL_CONFIG = 1, - MDL_CONFIG = 2 -}; - -enum class PseConfig { - NO_PSE = 0, - EXIST_PSE = 1 -}; - -enum class AttenMaskConfig { - NO_ATTEN_MASK = 0, - EXIST_ATTEN_MASK = 1 -}; - -enum class DropOutConfig { - NO_DROP_OUT = 0, - EXIST_DROP_OUT = 1 -}; - -enum class CubeFormatEnum { - ND = 0, - NZ = 1 -}; -enum class LayoutEnum { - BSND = 0, - SBND = 1, - BNSD = 2, - TND = 3, - NTD_TND = 4 -}; - -enum class CubeInputSourceEnum { - GM = 0, - L1 = 1 -}; - -enum class OptionEnum { - DISABLE = 0, - ENABLE = 1 -}; - -enum class SparseEnum { - ALL = 0, - NONE = 1, - ANY = 2, - CAUSAL = 3, - BAND = 4, - PREFIX = 5, - BAND_COMPRESS = 6, - RIGHT_DOWN_CAUSAL = 7, - RIGHT_DOWN_CAUSAL_BAND = 8, - BAND_LEFT_UP_CAUSAL = 9 -}; - -constexpr uint64_t RecursiveSum() -{ - return 0; -} - -constexpr int64_t base10Multiplier = 10; - -template constexpr uint64_t RecursiveSum(T templateId, Args... templateIds) -{ - return static_cast(templateId) + base10Multiplier * RecursiveSum(templateIds...); -} - -// TilingKey 的生成规则: -// FlashAttentionScore/FlashAttentionScoreGrad 十进制位组装tiling key,包含以下关键参数,从低位到高位依次是:Ub0, Ub1, -// Block, DataType, Format, Sparse, 特化模板 Ub0、Ub1: -// 表示Ub核内切分的轴,使用枚举AxisEnum表示,因为我们允许最多切分两根轴,所以存在UB0和UB1,如果没有UB核内切分, -// 那么填AXIS_NONE。UB0和UB1各占一个十进制位; -// Block: 表示UB用来分核的轴,使用枚举AxisEnum表示,占一个十进制位; -// DataType: 表示当前tiling key支持的输入输出的数据类型,使用枚举SupportedDtype来表示,占一个十进制位 -// Format: 表示当前tiling key支持的Format, 使用枚举InputLayout表示,占一个十进制位 -// Sparse: 表示当前tiling key是否支持Sparse,使用枚举SparseCapability表示,占一个十进制位 -// 其余特化场景,定义自己的位域和值 -// usage: get tilingKey from inputed types -// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2, -// SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL) - -constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19 -template constexpr uint64_t GET_TILINGKEY(Args... templateIds) -{ - return TILINGKEYOFFSET + RecursiveSum(templateIds...); -} - -// usage: get tilingKey from inputed types -// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL) - -#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \ - (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \ - SparseEnum::sparse)) - -} // namespace optiling diff --git a/csrc/moe_gating_top_k/tiling_base/tiling_util.h b/csrc/moe_gating_top_k/tiling_base/tiling_util.h deleted file mode 100644 index fb6ffa2d..00000000 --- a/csrc/moe_gating_top_k/tiling_base/tiling_util.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify. - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file tiling_util.h - * \brief - */ - -#pragma once - -#include "register/op_impl_registry.h" - -namespace Ops { -namespace Transformer { -namespace OpTiling { -bool IsRegbaseSocVersion(const gert::TilingParseContext* context); - -bool IsRegbaseSocVersion(const gert::TilingContext* context); - -const gert::Shape& EnsureNotScalar(const gert::Shape& inShape); -} // namespace OpTiling -} // namespace Transformer -} // namespace Ops \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 2a31e235..a8077b67 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1118,60 +1118,6 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons return combined_x; } -std::tuple moe_gating_top_k( - const at::Tensor& x, - int64_t k, - int64_t kGroup, - int64_t groupCount, - int64_t groupSelectMode, - int64_t renorm, - int64_t normType, - bool outFlag, - double routedScalingFactor, - double eps, - const c10::optional& biasOptional - ) -{ - TORCH_CHECK(x.dim() == 2, "The x should be 2D"); - TORCH_CHECK( - x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16, - "float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ", - x.scalar_type()); - - auto x_size = x.sizes(); - auto rows = x_size[0]; - auto expert_num = x_size[1]; - const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); }); - if (bias.defined()) { - TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same"); - TORCH_CHECK(bias.dim() == 1, "The bias should be 1D"); - auto bias_size = bias.sizes(); - TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim"); - } - at::Tensor yOut = at::empty({rows, k}, x.options()); - at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt)); - at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat)); - - EXEC_NPU_CMD(aclnnMoeGatingTopK, - x, // input_x - biasOptional, - k, // k - kGroup, // k_group - groupCount, // group_count - groupSelectMode, // group_select_mode - renorm, // renorm - normType, // norm_type - outFlag, // out_flag - routedScalingFactor, // routed_scaling_factor - eps, // eps - yOut, // input_y (注意:这里应该是 yOut) - expertIdxOut, // output - outOut - ); - - return std::tuple(outOut,expertIdxOut, yOut); -} - std::tuple npu_moe_init_routing_custom( const at::Tensor &x, const at::Tensor &expert_idx, const c10::optional &scale, const c10::optional &offset, int64_t active_num, @@ -1275,25 +1221,8 @@ std::tuple npu_moe_init_routing_ } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) -{ +{ // vLLM-Ascend custom ops - ops.def( - "moe_gating_top_k(Tensor x, " - "int k, " - "int kGroup, " - "int groupCount, " - "int groupSelectMode, " - "int renorm, " - "int normType, " - "bool outFlag, " - "float routedScalingFactor, " - "float eps," - "Tensor? biasOptional=None)" - - "-> (Tensor outOut,Tensor expertIdxOut, Tensor yOut)" - ); - ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k); - //Moe_gating ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index afd6a680..af2c237a 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -283,42 +283,6 @@ std::tuple matmul_allreduce_add_rmsnorm_meta( return {output, add_out}; } -std::tuple moe_gating_top_k_meta( - const at::Tensor& x, - int64_t k, - int64_t kGroup, - int64_t groupCount, - int64_t groupSelectMode, - int64_t renorm, - int64_t normType, - bool outFlag, - double routedScalingFactor, - double eps, - const c10::optional& biasOptional) -{ - TORCH_CHECK(x.dim() == 2, "The x should be 2D"); - TORCH_CHECK( - x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16, - "float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ", - x.scalar_type()); - - auto x_size = x.sizes(); - auto rows = x_size[0]; - auto expert_num = x_size[1]; - const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); }); - if (bias.defined()) { - TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same"); - TORCH_CHECK(bias.dim() == 1, "The bias should be 1D"); - auto bias_size = bias.sizes(); - TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim"); - } - at::Tensor yOut = at::empty({rows, k}, x.options()); - at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt)); - at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat)); - - return std::tuple(outOut,expertIdxOut, yOut); -} - std::tuple npu_moe_init_routing_custom_meta( const at::Tensor &x, const at::Tensor &expert_idx, const c10::optional &scale, const c10::optional &offset, int64_t active_num, @@ -403,15 +367,12 @@ std::tuple npu_moe_init_routing_ } } // namespace meta - } // namespace vllm_ascend namespace { // Register the meta implementations of the custom kernels for symbolic tracing, this will also // the custom kernel been captured into aclgraph TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { - // Moe_gating_top_k - ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta); // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation diff --git a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py index 7760e532..03b4b964 100644 --- a/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py @@ -179,8 +179,7 @@ class SmallOps(DecodeMoeOps): shared_expert_rank_num=self.shared_expert_rank_num, quant_mode=2, global_bs=self.batch_size * self.ep_world_size, - expert_token_nums_type= - 1, # 0 represents prefix sum, 1 represents individual counts + expert_token_nums_type=1, # 0代表前缀和,1代表各自数量 ) expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs output_dtype = x.dtype @@ -189,8 +188,8 @@ class SmallOps(DecodeMoeOps): x=[expand_x], weight=[self.gmm1_weight], split_item=3, - group_list_type=1, # Default is 0, represents prefix sum format - group_type=0, # 0 represents m-axis grouping + group_list_type=1, # 默认为0,代表前缀和形式 + group_type=0, # 0代表m轴分组 group_list=expert_token_nums, output_dtype=torch.int32)[0] y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( @@ -366,7 +365,7 @@ def run_once(local_rank_id, with_mc2_mask=False): log_file = redirect_output(f"local_rank_{local_rank_id}.log" ) if output_to_file(local_rank_id) else None - global_rank_id = local_rank_id # Single machine + global_rank_id = local_rank_id # 单机 device_id = local_rank_id % 16 torch_npu.npu.set_device(device_id) diff --git a/tests/e2e/nightly/ops/test_npu_moe_gating_top_k.py b/tests/e2e/nightly/ops/test_npu_moe_gating_top_k.py deleted file mode 100644 index f0a64748..00000000 --- a/tests/e2e/nightly/ops/test_npu_moe_gating_top_k.py +++ /dev/null @@ -1,322 +0,0 @@ -import itertools -import logging -import random -from typing import Optional, Tuple - -import numpy as np -import torch -from torch_npu.testing.testcase import TestCase, run_tests - -try: - from vllm_ascend.utils import enable_custom_op - enable_custom_op() -except ImportError: - logging.warning( - "vllm_ascend.utils.enable_custom_op not found, skip custom op initialization" - ) - - def enable_custom_op() -> None: - pass - - -# Set random seed for reproducibility -SEED = 45 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -if hasattr(torch, "npu") and torch.npu.is_available(): - torch.npu.manual_seed_all(SEED) - -# Configure logging -logging.basicConfig(level=logging.INFO, - format="[%(asctime)s] %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S") -logger = logging.getLogger(__name__) - - -def softmax_func( - x: np.ndarray, - axis: Optional[int] = None, - eps: float = 1e-20) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Stable softmax implementation for MOE gating. - - Args: - x: Input array - axis: Axis to compute softmax - eps: Epsilon to avoid division by zero - - Returns: - softmax_output: Softmax result - x_max: Max value for numerical stability - x_sum: Sum of exponentials - """ - if "float16" in x.dtype.name: - x = x.astype(np.float32) - - x_max = x.max(axis=axis, keepdims=True) - x_sub = x - x_max - y = np.exp(x_sub) - x_sum = y.sum(axis=axis, keepdims=True) - softmax_output = y / (x_sum + eps) - - return softmax_output, x_max, x_sum - - -class TestNpuMoeGatingTopK(TestCase): - """Test suite for NPU MOE Gating Top-K operator compatibility.""" - - def moe_gating_top_k_np( - self, - x: np.ndarray, - k: int, - bias: Optional[np.ndarray] = None, - k_group: int = 1, - group_count: int = 1, - group_select_mode: int = 0, - renorm: int = 0, - norm_type: int = 0, - y2_flag: bool = False, - routed_scaling_factor: float = 1.0, - eps: float = 1e-20 - ) -> Tuple[torch.Tensor, np.ndarray, Optional[np.ndarray]]: - """ - NumPy reference implementation of MOE gating Top-K logic. - - Args: - x: Input features, shape [batch_size, num_experts] - k: Number of experts to select per sample - bias: Gating bias, shape [num_experts] - k_group: Number of groups to select (group mode) - group_count: Number of expert groups - group_select_mode: 0 (max per group), 1 (sum of top-2 per group) - renorm: Whether to renormalize weights (1=enable, 0=disable) - norm_type: 0 (softmax), 1 (sigmoid) - y2_flag: Whether to return original x as y2 - routed_scaling_factor: Weight scaling factor - eps: Epsilon for numerical stability - - Returns: - y: Selected expert weights (Tensor) - indices: Selected expert indices (int32 numpy array) - y2: Original x if y2_flag=True, else None - """ - # Convert torch tensors to numpy arrays if needed (compatibility layer) - if isinstance(x, torch.Tensor): - x = x.cpu().numpy() - if isinstance(bias, torch.Tensor): - bias = bias.cpu().numpy() - - # Type conversion for numerical stability - orig_dtype = x.dtype - if orig_dtype != np.float32: - x = x.astype(np.float32) - if bias is not None: - bias = bias.astype(np.float32) - - # Apply normalization (softmax/sigmoid) - if norm_type == 0: - x, _, _ = softmax_func(x, axis=-1, eps=eps) - else: - x = 1 / (1 + np.exp(-x)) # Sigmoid - - original_x = x.copy() - - # Apply bias if provided - if bias is not None: - x = x + bias - - # Group-based expert selection - if group_count > 1: - batch_size, num_experts = x.shape - if num_experts % group_count != 0: - raise ValueError( - f"num_experts ({num_experts}) must be divisible by group_count ({group_count})" - ) - group_size = num_experts // group_count - - # Reshape to [batch, groups, group_size] - x_reshaped = x.reshape(batch_size, group_count, group_size) - - # Compute group scores - if group_select_mode == 0: - group_scores = np.amax(x_reshaped, axis=-1) - else: - # Sum of top-2 values per group - group_scores = np.partition(x_reshaped, -2, - axis=-1)[..., -2:].sum(axis=-1) - - # Select top-k_group groups - top_groups = np.argsort(-group_scores, axis=-1, - kind="stable")[:, :k_group] - - # Mask out non-selected groups with -inf - mask = np.ones((batch_size, group_count), dtype=bool) - mask[np.arange(batch_size)[:, None], top_groups] = False - x_reshaped = np.where(mask[..., None], float("-inf"), x_reshaped) - - # Reshape back to original - x = x_reshaped.reshape(batch_size, num_experts) - - # Select top-k experts - x_tensor = torch.from_numpy(x) - _, topk_indices = torch.sort(x_tensor, - dim=-1, - stable=True, - descending=True) - topk_indices = np.asarray(topk_indices[:, :k], dtype=np.int32) - - # Extract weights for selected experts - selected_weights = np.take_along_axis(original_x, topk_indices, axis=1) - - # Apply renormalization if needed - if norm_type == 1 or renorm == 1: - weight_sum = np.sum(selected_weights, axis=-1, keepdims=True) - selected_weights = selected_weights / (weight_sum + eps) - - # Apply scaling factor - selected_weights *= routed_scaling_factor - - # Prepare y2 output - y2 = original_x if y2_flag else None - - # Convert back to torch tensor with original dtype - selected_weights_tensor = torch.tensor(selected_weights, - dtype=orig_dtype) - - return selected_weights_tensor, topk_indices, y2 - - def test_npu_moe_gating_topk_multi(self) -> None: - """ - Multi-case test for NPU MOE Gating Top-K operator. - Validates compatibility with different input shapes and parameter combinations. - """ - # Test parameter space (aligned with vllm-ascend use cases) - test_configs = { - "group_select_modes": [0, 1], - "renorms": [1], - "norm_types": [0, 1], - "group_counts": [1, 8], - "k_ranges": [4, 8, 12, 16, 6, 32], - "x_dim0": range(1, 17), # Batch size 1-16 - "x_dim1": [256, 128, 64, 208, 192, 160] # Expert counts - } - - # Generate parameter combinations - param_combinations = itertools.product( - test_configs["group_select_modes"], test_configs["renorms"], - test_configs["norm_types"], test_configs["group_counts"], - test_configs["k_ranges"], test_configs["x_dim0"], - test_configs["x_dim1"]) - - # Limit test cases to avoid excessive runtime (adjust as needed) - max_test_cases = 100 - tested_cases = 0 - - for params in param_combinations: - if tested_cases >= max_test_cases: - break - - (group_select_mode, renorm, norm_type, group_count, k, dim0, - dim1) = params - - # Skip invalid configurations - if group_count > 1: - if dim1 % group_count != 0: - continue - if k > (dim1 // group_count): - continue - - # Generate random inputs (consistent with vllm-ascend input distribution) - x_np = np.random.uniform(-2.0, 2.0, - (dim0, dim1)).astype(np.float32) - bias_np = np.random.uniform(-2.0, 2.0, (dim1, )).astype(np.float32) - - # Convert to torch tensors - x_tensor = torch.tensor(x_np, dtype=torch.float32) - bias_tensor = torch.tensor(bias_np, dtype=torch.float32) - - # Random k_group (within valid range) - k_group = random.randint(1, min(group_count, 4)) - - # Fixed parameters (aligned with NPU OP defaults) - y2_flag = False - routed_scaling_factor = 1.0 - eps = 1e-20 - - try: - # Get NumPy reference result - ref_weights, ref_indices, ref_y2 = self.moe_gating_top_k_np( - x=x_tensor, - k=k, - bias=bias_tensor, - k_group=k_group, - group_count=group_count, - group_select_mode=group_select_mode, - renorm=renorm, - norm_type=norm_type, - y2_flag=y2_flag, - routed_scaling_factor=routed_scaling_factor, - eps=eps) - - # Skip if NPU OP is not available - if not hasattr(torch.ops, "_C_ascend") or not hasattr( - torch.ops._C_ascend, "moe_gating_top_k"): - logger.warning( - "NPU MOE gating OP not found, skipping NPU test") - continue - - # Get NPU OP result - npu_weights, npu_indices, npu_y2 = torch.ops._C_ascend.moe_gating_top_k( - x=x_tensor.npu(), - k=k, - kGroup=k_group, - groupCount=group_count, - groupSelectMode=group_select_mode, - renorm=renorm, - normType=norm_type, - outFlag=y2_flag, - routedScalingFactor=routed_scaling_factor, - eps=eps, - biasOptional=bias_tensor.npu() - if bias_tensor is not None else None) - - # Convert NPU results to CPU for comparison - npu_weights_cpu = npu_weights.cpu() - npu_indices_cpu = npu_indices.cpu().numpy() - - # Log test case info (vllm-ascend standard format) - logger.info( - f"Test Case {tested_cases + 1}: " - f"x_shape=({dim0},{dim1}), k={k}, group_count={group_count}, " - f"select_mode={group_select_mode}, norm_type={norm_type}, renorm={renorm}" - ) - - # Validate results (RTOL=1e-3 is standard for NPU numerical tolerance) - self.assertRtolEqual(ref_weights, - npu_weights_cpu, - rtol=1e-3, - atol=1e-5) - self.assertRtolEqual(ref_indices, npu_indices_cpu) - - # Validate y2 if enabled - if y2_flag: - self.assertRtolEqual(ref_y2, - npu_y2.cpu().numpy(), - rtol=1e-3, - atol=1e-5) - - tested_cases += 1 - logger.info(f"Test Case {tested_cases} passed ") - - except Exception as e: - logger.error(f"Test Case failed with error: {str(e)}", - exc_info=True) - continue - - logger.info(f"Completed {tested_cases}/{max_test_cases} test cases") - - -if __name__ == "__main__": - # Run tests with vllm-ascend standard verbosity - run_tests(verbosity=2) diff --git a/tests/ut/model_loader/netloader/test_netloader_elastic.py b/tests/ut/model_loader/netloader/test_netloader_elastic.py index f833a129..127f1dd6 100644 --- a/tests/ut/model_loader/netloader/test_netloader_elastic.py +++ b/tests/ut/model_loader/netloader/test_netloader_elastic.py @@ -311,7 +311,7 @@ def test_client_handler_mismatch(server_config): mismatch_data = { "label": "JOIN", "content": { - "device_id": 1, # Mismatched ID + "device_id": 1, # 不匹配的ID "model_path": "/wrong/model", "tp": 2, "pp": 2, diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 4fed3168..e3f6dd14 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -670,7 +670,7 @@ class TestNPUWorker(TestBase): (5000, 10000), ] - # Create worker mock + # 创建 worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): worker = NPUWorker() worker.init_npu_memory = 8500 diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index e30d36e7..51e0cb9f 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -17,6 +17,7 @@ from typing import Callable, Optional import torch +import torch_npu from vllm_ascend.utils import get_weight_prefetch_method @@ -213,19 +214,21 @@ def _select_experts_with_fusion_ops( e_score_correction_bias.dtype != router_logits.dtype: e_score_correction_bias = e_score_correction_bias.to( router_logits.dtype) - _, topk_ids, topk_weights = torch.ops._C_ascend.moe_gating_top_k( + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( router_logits, k=top_k, - kGroup=topk_group, - groupCount=num_expert_group, - groupSelectMode=1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=1, # 0: softmax->topk(fix); 1: topk->softmax - normType=norm_type, # 0: softmax; 1: sigmoid - outFlag=False, # todo new api; should the third output be output - routedScalingFactor=1, - eps=float(1e-20), - biasOptional=e_score_correction_bias, - ) + bias=e_score_correction_bias, + k_group=topk_group, + group_count=num_expert_group, + group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=norm_type, # 0: softmax; 1: sigmoid + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + if scoring_func == "softmax": + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids