diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 929fdc3e..794842f5 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" + CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -82,6 +82,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "add_rms_norm_bias" "apply_top_k_top_p_custom" "transpose_kv_cache_by_block" + "moe_grouped_matmul" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/moe_grouped_matmul/op_host/CMakeLists.txt b/csrc/moe_grouped_matmul/op_host/CMakeLists.txt new file mode 100644 index 00000000..40666baa --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/CMakeLists.txt @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME MoeGroupedMatmulCustom + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(optiling PRIVATE + moe_grouped_matmul_cpu.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include/external + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof +) + +target_sources(opsproto PRIVATE + moe_grouped_matmul_infershape.cpp +) + +target_include_directories(opsproto PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include/external + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof +) + +target_sources(op_host_aclnnInner PRIVATE + moe_grouped_matmul_def.cpp +) + +target_include_directories(op_host_aclnnInner PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include/external + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof +) + +target_sources(opapi PRIVATE + moe_grouped_matmul_l0.cpp + aclnn_moe_grouped_matmul.cpp +) + +install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul.h" + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL) + +install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul_weight_nz.h" + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL) \ No newline at end of file diff --git a/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.cpp b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.cpp new file mode 100644 index 00000000..0878230a --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.cpp @@ -0,0 +1,248 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "aclnn_moe_grouped_matmul.h" +#include "aclnn_moe_grouped_matmul_weight_nz.h" + +#include +#include + +#include "aclnn_kernels/transdata.h" +#include "moe_grouped_matmul_l0.h" +#include "aclnn_kernels/contiguous.h" +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_kernels/common/op_error_check.h" +#include "opdev/common_types.h" +#include "opdev/data_type_utils.h" +#include "opdev/format_utils.h" +#include "opdev/op_dfx.h" +#include "opdev/op_executor.h" +#include "opdev/op_log.h" +#include "opdev/platform.h" +#include "opdev/shape_utils.h" +#include "opdev/tensor_view_utils.h" +#include "opdev/make_op_executor.h" + +using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +namespace { + static constexpr size_t ALIGN_NZ_4BIT_N = 64UL; + static constexpr size_t ALIGN_NZ_4BIT_K = 64UL; + static constexpr size_t ALIGN_NZ_INT8_N = 32UL; + static constexpr size_t ALIGN_NZ_K = 16UL; + static constexpr size_t DIMS_THREE_FOR_GMM = 3UL; + static constexpr size_t LAST_FIRST_DIM_INDEX = 1; + static constexpr size_t LAST_SECOND_DIM_INDEX = 2; + static constexpr size_t LAST_THIRD_DIM_INDEX = 3; + + enum class GMMWeightVersion : uint32_t { + WeightNd = 1U, + WeightNz = 2U + }; + + struct MoeGroupedMatmulParams { + const aclTensorList *x = nullptr; + const aclTensorList *weight = nullptr; + const aclTensor *groupTensor = nullptr; + bool transposeWeight = false; + bool isSingleWeight = false; + GMMWeightVersion weightVersion = GMMWeightVersion::WeightNd; + const aclTensorList *y = nullptr; + DataType xDtype = DataType::DT_BF16; + }; +} + + +namespace { + +static aclnnStatus CheckNotNull(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *y) { + CHECK_COND(x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "x must not be nullptr."); + CHECK_COND(x->Size() != 0, ACLNN_ERR_PARAM_INVALID, "x must not be empty tensorlist."); + CHECK_COND(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight must not be nullptr."); + CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "weight must not be empty tensorlist."); + CHECK_COND(y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "y must not be nullptr."); + CHECK_COND(y->Size() != 0, ACLNN_ERR_PARAM_INVALID, "y must not be empty tensorlist."); + return ACLNN_SUCCESS; +} + +static aclnnStatus TransWeightToNzCheckAlign(MoeGroupedMatmulParams &gmmParams, const aclTensor *weight) +{ + size_t viewDimNum = weight->GetViewShape().GetDimNum(); + uint64_t k = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - 1) : + weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX); + uint64_t n = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX) : + weight->GetViewShape().GetDim(viewDimNum - 1); + bool k_align = false; + bool n_align = false; + if (weight->GetDataType() == DataType::DT_BF16 || weight->GetDataType() == DataType::DT_FLOAT16) { + k_align = k % ALIGN_NZ_K == 0; + n_align = n % ALIGN_NZ_K == 0; + } + CHECK_COND(k_align == true && n_align == true, ACLNN_ERR_PARAM_INVALID, + "When weight format is FRACTAL_NZ, weight'shape(k[%lu], n[%lu]) should be divisible by the " + "following shape: BF16/FP16[16, 16]). 
If the weight is transposed," + "the k/n need to be reversed.", + k, n); + return ACLNN_SUCCESS; +} + +static aclnnStatus TransWeightToNz(MoeGroupedMatmulParams &gmmParams, aclOpExecutor *executor) { + const aclTensorList *&weights = gmmParams.weight; + const aclTensorList *&x = gmmParams.x; + CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "The first tensor of x is nullptr!"); + size_t wLength = weights->Size(); + for (size_t i(0); i < wLength; ++i) { + const aclTensor* weight = (*weights)[i]; + if (weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ && + weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_16 && + weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_32) { + break; + } + TransWeightToNzCheckAlign(gmmParams, weight); + continue; + } + return ACLNN_SUCCESS; +} + +static const aclTensor *SetTensorToNZFormat(const aclTensor *input, op::Shape &shape, aclOpExecutor *executor) { + auto formatTensor = executor->CreateView(input, shape, input->GetViewOffset()); + formatTensor->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ); + formatTensor->SetOriginalFormat(input->GetViewFormat()); + formatTensor->SetViewShape(input->GetViewShape()); + return formatTensor; +} + +static aclnnStatus DataContiguous(const aclTensorList *&tensors, aclOpExecutor *executor) { + std::vector tensorsVec; + const aclTensor *contiguousTensor = nullptr; + for (size_t i = 0; i < tensors->Size(); ++i) { + const aclTensor *tensor = (*tensors)[i]; + contiguousTensor = l0op::Contiguous(tensor, executor); + CHECK_RET(contiguousTensor != nullptr, ACLNN_ERR_INNER_NULLPTR); + tensorsVec.push_back(contiguousTensor); + } + tensors = executor->AllocTensorList(tensorsVec.data(), tensorsVec.size()); + return ACLNN_SUCCESS; +} + +static aclnnStatus ParamsDataContiguous(MoeGroupedMatmulParams ¶ms, aclOpExecutor *executorPtr) { + CHECK_COND(DataContiguous(params.x, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Contiguous x failed."); // make x contiguous + DataType xDtype = (*params.x)[0]->GetDataType(); + DataType weightDtype = (*params.weight)[0]->GetDataType(); + CHECK_COND(DataContiguous(params.weight, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "Contiguous weight failed."); // make w contiguous + params.groupTensor = l0op::Contiguous(params.groupTensor, executorPtr); + CHECK_COND(params.groupTensor != nullptr, ACLNN_ERR_PARAM_INVALID, + "Contiguous groupTensor failed."); + return ACLNN_SUCCESS; +} + +static aclnnStatus GetGMMResultByL0Api(MoeGroupedMatmulParams ¶ms, uint64_t *workspaceSize, aclOpExecutor **executor) { + auto uniqueExecutor = CREATE_EXECUTOR(); // fixed written style, create OpExecutor + aclOpExecutor *executorPtr = uniqueExecutor.get(); + CHECK_RET(executorPtr != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + // op::Shape wqbmmNzShape = (*params.weight)[0]->GetStorageShape(); + + CHECK_COND(ParamsDataContiguous(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + "ParamsDataContiguous failed."); + if (params.weightVersion == GMMWeightVersion::WeightNz) { + std::vector tensorsVec; + for (size_t i = 0; i < params.weight->Size(); ++i) { + const aclTensor *tensor = (*params.weight)[i]; + op::Shape weightNzShape = tensor->GetViewShape(); + tensor = SetTensorToNZFormat(tensor, weightNzShape, executorPtr); + tensorsVec.push_back(tensor); + } + params.weight = executorPtr->AllocTensorList(tensorsVec.data(), tensorsVec.size()); + } + CHECK_COND(TransWeightToNz(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, + 
"TransWeightToNz failed."); + // Invoke l0 operator MoeGroupedMatmul for calculation. + auto result = l0op::MoeGroupedMatmul(params.x, params.weight, + params.groupTensor, params.transposeWeight, + (*params.y)[0]->GetViewShape(), params.y->Size(), + (*params.y)[0]->GetDataType(), executorPtr); + CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR); + for (size_t i(0); i < params.y->Size(); ++i) { + auto viewCopyResult = l0op::ViewCopy((*result)[i], (*params.y)[i], executorPtr); + CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR); + } + // Standard syntax, get the size of workspace needed during computation. + *workspaceSize = uniqueExecutor->GetWorkspaceSize(); + uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; +} + + +static aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(const aclTensorList *x, const aclTensorList *weight, + const aclTensor *groupList, bool transposeWeight, GMMWeightVersion weightVersion, + const aclTensorList *y, uint64_t *workspaceSize, aclOpExecutor **executor) { + DataType xDtype = DataType::DT_UNDEFINED; + for (size_t i = 0; i < x->Size(); ++i) { + if ((*x)[i] != nullptr) { + xDtype = (*x)[i]->GetDataType(); + break; + } + } + MoeGroupedMatmulParams moeGmmParams{x, weight, groupList, transposeWeight, true, weightVersion, y, xDtype}; + aclnnStatus ret = GetGMMResultByL0Api(moeGmmParams, workspaceSize, executor); + return ret; +} +} + +aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight, + const aclTensor *groupList, bool transposeWeight, aclTensorList *out, + uint64_t *workspaceSize, aclOpExecutor **executor) { + CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR, + "one of required inputs is nullptr."); + // Standard syntax, Check parameters. + L2_DFX_PHASE_1(aclnnMoeGroupedMatmulWeightNz, + DFX_IN(x, weight, groupList), + DFX_OUT(out)); + return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNz, out, workspaceSize, executor); +} + +aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight, + const aclTensor *groupList, bool transposeWeight, aclTensorList *out, + uint64_t *workspaceSize, aclOpExecutor **executor) { + CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR, + "one of required inputs is nullptr."); + // Standard syntax, Check parameters. 
+ L2_DFX_PHASE_1(aclnnMoeGroupedMatmul, + DFX_IN(x, weight, groupList), + DFX_OUT(out)); + CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "weight should not be null tensorlist "); + return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNd, out, workspaceSize, executor); +} + +aclnnStatus aclnnMoeGroupedMatmul(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, + aclrtStream stream) { + L2_DFX_PHASE_2(aclnnMoeGroupedMatmul); + CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER, + "This is an error in GMM launch aicore"); + return ACLNN_SUCCESS; +} + +aclnnStatus aclnnMoeGroupedMatmulWeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, + aclrtStream stream) { + L2_DFX_PHASE_2(aclnnMoeGroupedMatmulWeightNz); + CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER, + "This is an error in GMM launch aicore"); + return ACLNN_SUCCESS; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.h b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.h new file mode 100644 index 00000000..cad1eccc --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.h @@ -0,0 +1,30 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef OP_API_INC_MOE_GROUPED_MATMUL_H +#define OP_API_INC_MOE_GROUPED_MATMUL_H +#include "aclnn/aclnn_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize( + const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList, + bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul_weight_nz.h b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul_weight_nz.h new file mode 100644 index 00000000..e1a915a7 --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul_weight_nz.h @@ -0,0 +1,30 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H +#define OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H +#include "aclnn/aclnn_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize( + const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList, + bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNz(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_cpu.cpp b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_cpu.cpp new file mode 100644 index 00000000..d7bdd81c --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_cpu.cpp @@ -0,0 +1,336 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "moe_grouped_matmul_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +#define OP_LOGD(nodeName, fmt, ...) printf(fmt, ##__VA_ARGS__); printf("\n") +#define OP_LOGE(nodeName, fmt, ...) 
printf(fmt, ##__VA_ARGS__); printf("\n") + +constexpr uint32_t X_INDEX = 0; +constexpr uint32_t WEIGHT_INDEX = 1; +constexpr uint32_t GROUPLIST_INDEX = 2; +namespace optiling { +constexpr uint64_t BEST_L1_PARTA = 256UL * 1024UL; +constexpr uint64_t BEST_L1_PARTB = 128UL * 1024UL; +constexpr uint64_t L1_PARTA_SIZE = 256UL * 1024UL; +constexpr int32_t BEST_BASEN = 256; +constexpr uint64_t DOUBLE_BUFFER_L0A_L0B = 2; +constexpr uint64_t DOUBLE_BUFFER_STEPKA_STEPKB = 2; +constexpr uint32_t FP32_DATATYPE_SIZE = 4; +constexpr int32_t MAX_BASEM = 256; + +static inline uint32_t SixteenAlign(uint32_t a, bool up = false) { + if (up) { + a += 15U; // 15: 16 bytes up-align + } + return a & ~15U; // ~15: 16 bytes down-align +} + +static inline int64_t SixteenAlign(int64_t a, bool up = false) { + if (up) { + a += 15; // 15: 16 bytes up-align + } + return a & ~15; // ~15: 16 bytes down-align +} + +class TilingMoeGroupedMatmulFunc { + public: + explicit TilingMoeGroupedMatmulFunc(gert::TilingContext* tiling_context) + : tiling_context_(tiling_context) {} + + ge::graphStatus Init(); + ge::graphStatus RunKernelTiling(); + + private: + MoeGroupedMatmulTilingData tiling_data_; + gert::TilingContext* tiling_context_ = nullptr; + + void SetTilingKey(); + void FillTilingData(); + ge::graphStatus CalMMTiling(); + ge::graphStatus GMMSetMMTiling(); + ge::graphStatus CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb); + ge::graphStatus DynamicTIlingSingleN(); + + void InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo); + void GMMGetPlatformInfo(); + int64_t m_ = 0L; + int64_t n_ = 0L; + int64_t k_ = 0L; + bool transpose_weight = false; + bool weight_nz = false; + uint32_t single_m_ = 0; + uint32_t single_n_ = 0; + int32_t baseM_ = 0; + int32_t baseN_ = 0; + int32_t baseK_ = 0; + int32_t nz_factor_ = 1; + uint32_t group_num_ = 0; + uint32_t core_num_ = 0; + uint32_t mmDataTypeSize_ = 0; + size_t sync_workspace_size_ = 0; + ge::DataType x_dtype; + uint64_t l1_size, l0a_size, l0b_size, l0c_size, ub_size; + uint32_t aic_num, aiv_num; + // SocVersion soc_version; +}; + +ge::graphStatus TilingMoeGroupedMatmulFunc::CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb) { + uint64_t available_l1_size = l1_size; + if (available_l1_size < L1_PARTA_SIZE) { + OP_LOGE(tiling_context_->GetNodeName(), "available_l1_size is less than 256k."); + return ge::GRAPH_FAILED; + } + // according to double buffer, recompute the params used for data movement from GM to L1 + uint64_t l1_a_size = baseM_ > baseN_ ? 
L1_PARTA_SIZE : available_l1_size - L1_PARTA_SIZE; + uint64_t l1_b_size = available_l1_size - l1_a_size; + // 2: double buffer + mm_step_ka = (l1_a_size / 2UL) / (static_cast(baseM_) * baseK_ * mmDataTypeSize_); + // 2: double buffer + mm_step_kb = (l1_b_size / 2UL) / (static_cast(baseN_) * baseK_ * mmDataTypeSize_); + if (mm_step_ka == 0 || mm_step_kb == 0) { + OP_LOGE(tiling_context_->GetNodeName(), "stepka or stepkb cannot be 0."); + return ge::GRAPH_FAILED; + } + + if (mm_step_ka > mm_step_kb) { + mm_step_ka = mm_step_ka / mm_step_kb * mm_step_kb; + } else if (mm_step_ka < mm_step_kb) { + mm_step_kb = mm_step_kb / mm_step_ka * mm_step_ka; + } + return ge::GRAPH_SUCCESS; +} + +void TilingMoeGroupedMatmulFunc::GMMGetPlatformInfo() { + auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo()); + platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1_size); + platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0a_size); + platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0b_size); + platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0c_size); + platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size); + aic_num = platform_info.GetCoreNumAic(); + aiv_num = platform_info.GetCoreNumAiv(); +} + +void TilingMoeGroupedMatmulFunc::InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo) { + auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo()); + platformInfo.socVersion = platform_info.GetSocVersion(); + platformInfo.l1Size = l1_size; + platformInfo.l0CSize = l0c_size; + platformInfo.ubSize = ub_size; + platformInfo.l0ASize = l0a_size; + platformInfo.l0BSize = l0b_size; +} + +ge::graphStatus TilingMoeGroupedMatmulFunc::GMMSetMMTiling() { + matmul_tiling::DataType matmul_dtype = static_cast(x_dtype); + matmul_tiling::PlatformInfo platformInfo; + InitPlatformInfo(platformInfo); + // matmul_tiling::MatmulApiTiling mm(platformInfo); + matmul_tiling::MultiCoreMatmulTiling mm(platformInfo); + mm.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, false); + if (weight_nz) { + mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_dtype, transpose_weight); + } else { + mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, transpose_weight); + } + mm.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype); + mm.SetOrgShape(m_, n_, k_); + mm.SetShape(m_, baseN_, k_); + // mm.SetShape(single_m_, single_n_, k_); + mm.SetFixSplit(baseM_, baseN_, baseK_); + mm.SetBufferSpace(l1_size, l0c_size, ub_size); + + if (mm.GetTiling(tiling_data_.mm_tiling) == -1) { + OP_LOGE(tiling_context_->GetNodeName(), "matmul getTiling failed."); + return ge::GRAPH_FAILED; + } + uint32_t mm_step_ka = 1; + uint32_t mm_step_kb = 1; + + auto ret = CalcStepKaKb(mm_step_ka, mm_step_kb); + if (ret != ge::GRAPH_SUCCESS) { + OP_LOGE(tiling_context_->GetNodeName(), "matmul calc stepka or stepkb failed."); + return ge::GRAPH_FAILED; + } + + constexpr uint32_t step_m = 1; // 1: step_m set fixed value 1 + constexpr uint32_t step_n = 1; // 1: step_n set fixed value 1 + uint32_t mm_depth_a1 = mm_step_ka * DOUBLE_BUFFER_STEPKA_STEPKB * step_m; + uint32_t mm_depth_b1 = mm_step_kb * DOUBLE_BUFFER_STEPKA_STEPKB * step_n; + tiling_data_.mm_tiling.set_shareMode(0); + tiling_data_.mm_tiling.set_dbL0C(1); // disable double buffer for LOC + tiling_data_.mm_tiling.set_baseM(baseM_); // set 
precomputed baseM + tiling_data_.mm_tiling.set_baseN(baseN_); // set precomputed baseN + tiling_data_.mm_tiling.set_baseK(baseK_); // set precomputed baseK + tiling_data_.mm_tiling.set_stepKa(mm_step_ka); // set precomputed mmStepKa + tiling_data_.mm_tiling.set_depthA1(mm_depth_a1); // set precomputed mmDepthA1 + tiling_data_.mm_tiling.set_stepKb(mm_step_kb); // set precomputed mmStepKb + tiling_data_.mm_tiling.set_depthB1(mm_depth_b1); // set precomputed mmDepthB1 + tiling_data_.mm_tiling.set_stepM(step_m); // set precomputed stepM + tiling_data_.mm_tiling.set_stepN(step_n); // set precomputed stepN + OP_LOGD(context->GetNodeName(), "GMM_tiling: baseM is %d, baseK is %d, baseN is %d, transpose_weight is %d, weight_nz is %d", baseM_, baseK_, baseN_, transpose_weight, weight_nz); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus TilingMoeGroupedMatmulFunc::CalMMTiling() { + baseN_ = BEST_BASEN; + if (x_dtype == ge::DT_BF16 || x_dtype == ge::DT_FLOAT16) { + mmDataTypeSize_ = 2; + } else { + OP_LOGE(tiling_context_->GetNodeName(), "only support bf16 or fp16."); + return ge::GRAPH_FAILED; + } + + baseK_ = static_cast((l0b_size / DOUBLE_BUFFER_L0A_L0B) / (static_cast(baseN_) * mmDataTypeSize_)); + baseK_ = static_cast(SixteenAlign(static_cast(baseK_))); + uint32_t max_base_m = static_cast(l0c_size / + (static_cast(baseN_) * FP32_DATATYPE_SIZE)); + baseM_ = std::min((l0a_size / DOUBLE_BUFFER_L0A_L0B) / + (static_cast(baseK_) * mmDataTypeSize_), max_base_m); + baseM_ = baseM_ > m_ ? SixteenAlign(m_, true) : SixteenAlign(static_cast(baseM_)); + + if (baseM_ > MAX_BASEM) { + baseM_ = MAX_BASEM; + } + + if (baseM_ == 0 || baseK_ == 0) { + OP_LOGE(tiling_context_->GetNodeName(), "baseM_ or baseN_ cannot be 0."); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus TilingMoeGroupedMatmulFunc::Init() { + GMMGetPlatformInfo(); + // only support singlex、singleweight、singley + bool is_single_x = (tiling_context_->GetDynamicInputTensor(X_INDEX, 1) == nullptr); + bool is_single_weight = (tiling_context_->GetDynamicInputTensor(WEIGHT_INDEX, 1) == nullptr); + bool is_single_y = (tiling_context_->GetOutputShape(1) == nullptr); + transpose_weight = static_cast(*(tiling_context_->GetAttrs()->GetAttrPointer(0))); + if (!(is_single_x && is_single_weight && is_single_y)) { + OP_LOGE(tiling_context_->GetNodeName(), "only support singlex and singleweight and singley."); + return ge::GRAPH_FAILED; + } + + auto x_shape = tiling_context_->GetDynamicInputShape(X_INDEX, 0)->GetOriginShape(); + auto weight_shape = tiling_context_->GetDynamicInputShape(WEIGHT_INDEX, 0)->GetOriginShape(); + auto group_list_shape = tiling_context_->GetInputShape(GROUPLIST_INDEX)->GetOriginShape(); + auto y_shape = tiling_context_->GetOutputShape(0)->GetOriginShape(); + auto weight_desc = tiling_context_->GetDynamicInputDesc(WEIGHT_INDEX, 0); + auto weight_format = static_cast(ge::GetPrimaryFormat(weight_desc->GetStorageFormat())); + x_dtype = tiling_context_->GetDynamicInputDesc(X_INDEX, 0)->GetDataType(); + + // printf("weight_format %d\n", weight_format); + weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ; + + // check input shape + if (x_shape.GetDimNum() != 2 || y_shape.GetDimNum() != 2) { + OP_LOGE(tiling_context_->GetNodeName(), "the dimNum of input and output should be 2, but got %zu, %zu.", static_cast(x_shape.GetDimNum()), static_cast(y_shape.GetDimNum())); + return ge::GRAPH_FAILED; + } + uint32_t weight_dim1, weight_dim2; + n_ = transpose_weight ? 
weight_shape.GetDim(1) : weight_shape.GetDim(2); + + if (group_list_shape.GetDimNum() != 2) { + OP_LOGE(tiling_context_->GetNodeName(), "only support key-value mode of groupList, the dimNum of groupList should be 2, but got %zu.", static_cast(group_list_shape.GetDimNum())); + return ge::GRAPH_FAILED; + } + + m_ = x_shape.GetDim(x_shape.GetDimNum() - 2); + k_ = x_shape.GetDim(x_shape.GetDimNum() - 1); + group_num_ = weight_shape.GetDim(0); + + if (weight_shape.GetDim(0) != group_num_) { + OP_LOGE(tiling_context_->GetNodeName(), "the dim0 of input weight should be equal to input groupList, but got %zu, %zu.", static_cast(weight_shape.GetDim(0)), static_cast(group_list_shape.GetDim(0))); + } + single_m_ = 128; + single_n_ = 256; + + core_num_ = aic_num; + auto n_task_num = (n_ + single_n_ - 1) / single_n_; + auto task_num = m_ * n_task_num; + if (task_num < core_num_) { + core_num_ = task_num; + } + + auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo()); + sync_workspace_size_ = static_cast(platform_info.GetLibApiWorkSpaceSize()); + return ge::GRAPH_SUCCESS; +} + +void TilingMoeGroupedMatmulFunc::SetTilingKey() { + uint64_t tiling_key = 10UL; + if (transpose_weight) { + tiling_key = tiling_key + 1UL; + } + tiling_context_->SetTilingKey(tiling_key); +} + +void TilingMoeGroupedMatmulFunc::FillTilingData() { + tiling_data_.set_m(static_cast(m_)); + tiling_data_.set_n(static_cast(n_)); + tiling_data_.set_k(static_cast(k_)); + tiling_data_.set_single_m(single_m_); + tiling_data_.set_single_n(single_n_); + tiling_data_.set_group_num(group_num_); + tiling_data_.set_core_num(core_num_); +} + +ge::graphStatus TilingMoeGroupedMatmulFunc::RunKernelTiling() { + auto ret = CalMMTiling(); + if (ret != ge::GRAPH_SUCCESS) { + OP_LOGE(context->GetNodeName(), "cal mmtiling failed."); + return ge::GRAPH_FAILED; + } + ret = GMMSetMMTiling(); + if (ret != ge::GRAPH_SUCCESS) { + OP_LOGE(context->GetNodeName(), "gmm set mmtiling failed."); + return ge::GRAPH_FAILED; + } + SetTilingKey(); + FillTilingData(); + size_t userWorkspaceSize = 0; + size_t *currentWorkspace = tiling_context_->GetWorkspaceSizes(1); + currentWorkspace[0] = userWorkspaceSize + sync_workspace_size_; + tiling_data_.SaveToBuffer(tiling_context_->GetRawTilingData()->GetData(), + tiling_context_->GetRawTilingData()->GetCapacity()); + tiling_context_->GetRawTilingData()->SetDataSize(tiling_data_.GetDataSize()); + tiling_context_->SetBlockDim(core_num_); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingForMoeGroupedMatmulFunc(gert::TilingContext *context){ + TilingMoeGroupedMatmulFunc tilingObject(context); + auto ret = tilingObject.Init(); + if(ret != ge::GRAPH_SUCCESS){ + OP_LOGE(context->GetNodeName(), "tiling Init failed."); + return ge::GRAPH_FAILED; + } + ret = tilingObject.RunKernelTiling(); + return ret; +} + +struct MatmulAllreduceAddRmsnormCompileInfo1 {}; +ge::graphStatus TilingParseForMatmulAllreduceAddRmsnorm1(gert::TilingParseContext *context) +{ + // (void)context; + return ge::GRAPH_SUCCESS; +} + + +IMPL_OP_OPTILING(MoeGroupedMatmul) + .Tiling(TilingForMoeGroupedMatmulFunc) + .TilingParse(TilingParseForMatmulAllreduceAddRmsnorm1); + +} // namespace optiling + diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_def.cpp b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_def.cpp new file mode 100644 index 00000000..2ed0f124 --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_def.cpp @@ -0,0 +1,53 @@ +/** +* This program is free software, you can 
redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_grouped_matmul_def.cpp + * \brief + */ + +#include "register/op_def_registry.h" +#include "moe_grouped_matmul_infershape.cpp" + +namespace ops { +class MoeGroupedMatmul : public OpDef { +public: + explicit MoeGroupedMatmul(const char *name) : OpDef(name) { + this->Input("x") + .ParamType(DYNAMIC) + .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("weight") + .ParamType(DYNAMIC) + .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + this->Input("group_list") + .ParamType(REQUIRED) + .DataTypeList({ge::DT_INT64, ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}); + this->Output("y") + .ParamType(DYNAMIC) + .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("transpose_weight") + .AttrType(OPTIONAL) + .Bool(false); + this->SetInferShape(ge::InferShape); + this->SetInferDataType(ge::InferDataType); + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + } + + +}; + +OP_ADD(MoeGroupedMatmul); + +} // namespace ops \ No newline at end of file diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_infershape.cpp b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_infershape.cpp new file mode 100644 index 00000000..3633a59f --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_infershape.cpp @@ -0,0 +1,35 @@ +#include "register/op_impl_registry.h" + + +#include + +namespace ge { +constexpr uint32_t X_INDEX = 0; +constexpr uint32_t WEIGHT_INDEX = 1; +constexpr uint32_t GROUPLIST_INDEX = 2; + +static ge::graphStatus InferShape(gert::InferShapeContext *context) { + const gert::Shape* x_shape = context->GetDynamicInputShape(X_INDEX, 0); + const gert::Shape* weight_shape = context->GetDynamicInputShape(WEIGHT_INDEX, 0); + bool transpose_weight = static_cast(*(context->GetAttrs()->GetAttrPointer(0))); + gert::Shape* y_shape = context->GetOutputShape(0); + *y_shape = *x_shape; + auto weight_desc = context->GetDynamicInputDesc(WEIGHT_INDEX, 0); + auto weight_format = static_cast(ge::GetPrimaryFormat(weight_desc->GetStorageFormat())); + bool weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ; + int64_t dim_n; + if (weight_nz) { + dim_n = transpose_weight ? (weight_shape->GetDim(1) * weight_shape->GetDim(3)) : + (weight_shape->GetDim(2) * weight_shape->GetDim(4)); + } else { + dim_n = transpose_weight ? 
weight_shape->GetDim(1) : weight_shape->GetDim(2);
+    }
+    y_shape->SetDim(1, dim_n);
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) {
+    const auto input_dtype = context->GetDynamicInputDataType(X_INDEX, 0);
+    context->SetOutputDataType(0, input_dtype);
+    return ge::GRAPH_SUCCESS;
+}
+} // namespace ge
\ No newline at end of file
diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.cpp b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.cpp
new file mode 100644
index 00000000..397a9ae7
--- /dev/null
+++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.cpp
@@ -0,0 +1,50 @@
+#include "moe_grouped_matmul_l0.h"
+#include "opdev/op_log.h"
+#include "opdev/op_dfx.h"
+#include "opdev/shape_utils.h"
+#include "opdev/make_op_executor.h"
+
+using namespace op;
+
+namespace l0op {
+OP_TYPE_REGISTER(MoeGroupedMatmul);
+
+const aclTensorList *MoeGroupedMatmul(const aclTensorList *x,
+                                      const aclTensorList *weight,
+                                      const aclTensor *groupList,
+                                      bool transposeWeight,
+                                      op::Shape yShape,
+                                      size_t outLength,
+                                      op::DataType yDtype,
+                                      aclOpExecutor *executor) {
+    L0_DFX(MoeGroupedMatmul, x, weight, groupList, transposeWeight, outLength);
+
+    std::vector<const aclTensor *> tensorsVec;
+    const aclTensor *x0 = x->Size() > 0 ? (*x)[0] : nullptr;
+    if (x0 == nullptr) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "(*x)[0] is nullptr.");
+        return nullptr;
+    }
+
+    for (size_t i(0); i < outLength; ++i) {
+        tensorsVec.emplace_back(executor->AllocTensor(yShape, yDtype));
+    }
+    auto out = executor->AllocTensorList(tensorsVec.data(), outLength);
+
+    auto x0_dim_num = x0->GetStorageShape().GetDimNum();
+    auto x0_dim0 = x0->GetStorageShape().GetDim(0);
+    printf("x0_dim_num %zu x0_dim0 %ld\n", x0_dim_num, x0_dim0);
+
+    auto ret = ADD_TO_LAUNCHER_LIST_AICORE(MoeGroupedMatmul,
+                                           OP_INPUT(x, weight, groupList),
+                                           OP_OUTPUT(out),
+                                           OP_ATTR(transposeWeight));
+    if (ret != ACLNN_SUCCESS) {
+        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed.");
+        return nullptr;
+    }
+
+    return out;
+}
+
+} // namespace l0op
\ No newline at end of file
diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.h b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.h
new file mode 100644
index 00000000..6dff4dc8
--- /dev/null
+++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.h
@@ -0,0 +1,27 @@
+/**
+* This program is free software, you can redistribute it and/or modify.
+* Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H +#define OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H + +#include "opdev/op_executor.h" + +namespace l0op { +const aclTensorList *MoeGroupedMatmul(const aclTensorList *x, + const aclTensorList *weight, + const aclTensor *groupList, + bool transposeWeight, + op::Shape yShape, + size_t outLength, + op::DataType yDtype, + aclOpExecutor *executor); +} + +#endif \ No newline at end of file diff --git a/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_tiling.h b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_tiling.h new file mode 100644 index 00000000..a4c99431 --- /dev/null +++ b/csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_tiling.h @@ -0,0 +1,26 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(MoeGroupedMatmulTilingData) +TILING_DATA_FIELD_DEF(uint32_t, group_num); +TILING_DATA_FIELD_DEF(uint32_t, core_num); +TILING_DATA_FIELD_DEF(uint32_t, m); +TILING_DATA_FIELD_DEF(uint32_t, n); +TILING_DATA_FIELD_DEF(uint32_t, k); +TILING_DATA_FIELD_DEF(uint32_t, single_m); +TILING_DATA_FIELD_DEF(uint32_t, single_n); +TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm_tiling); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeGroupedMatmul, MoeGroupedMatmulTilingData) +} // namespace optiling diff --git a/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.cpp b/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.cpp new file mode 100644 index 00000000..c05c9969 --- /dev/null +++ b/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.cpp @@ -0,0 +1,41 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "moe_grouped_matmul.h" +#include "kernel_operator.h" + +#if defined(FORMAT_WEIGHT) && FORMAT_WEIGHT == FORMAT_FRACTAL_NZ +constexpr CubeFormat formatWeight = CubeFormat::NZ; +#else +constexpr CubeFormat formatWeight = CubeFormat::ND; +#endif + +//using namespace matmul; +#define GMM_CUBE_IMP(transWeight) \ + do { \ + if ASCEND_IS_AIV { \ + return; \ + } \ + GET_TILING_DATA(tiling_data, tiling); \ + AscendC::TPipe pipe; \ + KernelMoeGMMNoQuant op(&pipe); \ + op.Init(x, weight, group_list, y, &tiling_data); \ + op.Process(); \ + } while (0) + +extern "C" __global__ __aicore__ void moe_grouped_matmul(GM_ADDR x, GM_ADDR weight, GM_ADDR group_list, GM_ADDR y, + GM_ADDR workSpace, GM_ADDR tiling) { + + if (TILING_KEY_IS(10UL)) { + GMM_CUBE_IMP(false); + } else if (TILING_KEY_IS(11UL)) { + GMM_CUBE_IMP(true); + } +} diff --git a/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.h b/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.h new file mode 100644 index 00000000..0f085c76 --- /dev/null +++ b/csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.h @@ -0,0 +1,186 @@ +/** +* This program is free software, you can redistribute it and/or modify. +* Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "lib/matmul_intf.h" +using namespace AscendC; + + +constexpr MatmulConfig matmulCFGUnitFlag{false, false, true, 0, 0, 0, false, false, false, false, false, 0, 0, 0, + 0, 0, 0, 0, true}; +struct GMMConfig { + uint32_t m = 0; + uint32_t k = 0; + uint32_t n = 0; + uint32_t baseM = 0; + uint32_t baseN = 0; + uint32_t mIdx = 0; + uint32_t nIdx = 0; + uint32_t blockDimM = 0; + uint32_t blockDimN = 0; + uint32_t singleM = 0; + uint32_t singleN = 0; + uint64_t wBaseOffset = 0; + uint64_t nAxisBaseOffset = 0; + uint64_t mAxisBaseOffset = 0; + uint64_t xBaseOffset = 0; + uint64_t yBaseOffset = 0; + uint64_t wOutOffset = 0; +}; + + +template +class KernelMoeGMMNoQuant { + +protected: + using xType = MatmulType; + using weightType = MatmulType; + using yType = MatmulType; + using biasType = MatmulType; + using mmT = matmul::MatmulImpl; + mmT mm; + + MoeGroupedMatmulTilingData tiling_; + AscendC::TPipe *pipe_ = nullptr; + + GlobalTensor x_gm_; + GlobalTensor weight_gm_; + GlobalTensor y_gm_; + GlobalTensor group_list_gm_; + ListTensorDesc x_list_; + ListTensorDesc weight_list_; + ListTensorDesc y_list_; + + uint32_t core_idx; + uint32_t used_core_num; + constexpr static bool transposeW = transWeight; + constexpr static uint32_t UB_BLOCK_UNIT_SIZE = 32; + +public: + __aicore__ inline KernelMoeGMMNoQuant(AscendC::TPipe *pipe) {pipe_ = pipe;} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR group_list, GM_ADDR y, const MoeGroupedMatmulTilingData *tiling) { + core_idx = GetBlockIdx(); + tiling_ = *tiling; + used_core_num = GetBlockNum(); + group_list_gm_.SetGlobalBuffer((__gm__ T2*)group_list); + x_list_.Init((__gm__ void*)x); + 
weight_list_.Init((__gm__ void*)weight); + y_list_.Init((__gm__ void*)y); + GM_ADDR x_first_addr = (__gm__ uint8_t*)x_list_.GetDataPtr<__gm__ uint8_t>(0); + GM_ADDR weight_first_addr = (__gm__ uint8_t*)weight_list_.GetDataPtr<__gm__ uint8_t>(0); + GM_ADDR y_first_addr = (__gm__ uint8_t*)y_list_.GetDataPtr<__gm__ uint8_t>(0); + x_gm_.SetGlobalBuffer((__gm__ T*)x_first_addr); + weight_gm_.SetGlobalBuffer((__gm__ T*)weight_first_addr); + y_gm_.SetGlobalBuffer((__gm__ T*)y_first_addr); + + mm.Init(&tiling_.mm_tiling, pipe_); + } + + __aicore__ inline void Process() { + + uint32_t group_list_inner_shape = 2u; + uint32_t group_list_shape_size = tiling_.group_num * group_list_inner_shape; + GMMConfig mn_config; + + for (uint32_t loop = 0, count = 0; loop < group_list_shape_size; loop += group_list_inner_shape) { + int32_t split_value = static_cast(group_list_gm_.GetValue(loop + 1)); + if (split_value <= 0) { + break; + } + uint32_t group_idx = static_cast(group_list_gm_.GetValue(loop)); + mn_config.mAxisBaseOffset += mn_config.m; + mn_config.xBaseOffset += mn_config.m * mn_config.k; + mn_config.yBaseOffset += mn_config.m * mn_config.n; + this->SetMNConfig(split_value, mn_config); + mn_config.nAxisBaseOffset = group_idx * mn_config.n; + if constexpr (formatWeight == CubeFormat::NZ) { + mn_config.wBaseOffset = AlignUp(mn_config.k, 16) * AlignUp(mn_config.nAxisBaseOffset, 16); + } else { + mn_config.wBaseOffset = mn_config.k * mn_config.nAxisBaseOffset; + } + mn_config.blockDimM = Ceil(mn_config.m, mn_config.singleM); + mn_config.blockDimN = Ceil(mn_config.n, mn_config.singleN); + uint32_t cur_count = count + mn_config.blockDimM * mn_config.blockDimN; + uint32_t cur_block = this->core_idx >= count ? this->core_idx : this->core_idx + used_core_num; + while (cur_block < cur_count) { + mn_config.mIdx = (cur_block - count) / mn_config.blockDimN; + mn_config.nIdx = (cur_block - count) % mn_config.blockDimN; + this->MMCompute(group_idx, mn_config, this->core_idx); + cur_block += used_core_num; + } + count = cur_count % used_core_num; + } + } + +protected: + __aicore__ inline uint32_t AlignUp(uint32_t a, uint32_t base) { + return (a + base - 1) / base * base; + } + + __aicore__ inline uint32_t Ceil(uint32_t a, uint32_t base) { + if (base == 0) { + return a; + } + return (a + base - 1) / base; + } + + __aicore__ inline void SetMNConfig(const int32_t split_value, GMMConfig & mn_config) { + mn_config.m = split_value; + mn_config.k = tiling_.k; + mn_config.n = tiling_.n; + mn_config.baseM = tiling_.single_m; + mn_config.baseN = tiling_.single_n; + mn_config.singleM = mn_config.baseM; + mn_config.singleN = mn_config.baseN; + } + + __aicore__ inline void MMCompute(uint32_t group_idx, GMMConfig& mn_config, uint32_t core_idx) { + uint32_t tail_n = mn_config.nIdx * mn_config.singleN; + uint32_t cur_single_n = mn_config.nIdx < mn_config.blockDimN - 1 ? mn_config.singleN : mn_config.n - tail_n; + uint32_t cur_single_m = mn_config.mIdx < mn_config.blockDimM - 1 ? 
mn_config.singleM + : mn_config.m - mn_config.mIdx * mn_config.singleM; + uint64_t x_offset = mn_config.mIdx * mn_config.singleM * mn_config.k; + uint64_t out_offset = mn_config.mIdx * mn_config.singleM * mn_config.n + tail_n; + GlobalTensor weight_gm_local = GetGlobalBufferW(group_idx, tail_n, mn_config); + + mm.SetOrgShape(mn_config.m, mn_config.n, mn_config.k); + mm.SetSingleShape(cur_single_m, cur_single_n, mn_config.k); + mm.SetTensorA(x_gm_[mn_config.xBaseOffset + x_offset], false); + mm.SetTensorB(weight_gm_local, transposeW); + mm.template IterateAll(y_gm_[mn_config.yBaseOffset + out_offset], 0); + } + + __aicore__ inline GlobalTensor GetGlobalBufferW(uint32_t group_idx, uint32_t tail_n, GMMConfig& mn_config) { + uint64_t w_offset = SetWOffset(tail_n, mn_config.k); + GlobalTensor weight_gm_local; + weight_gm_local = weight_gm_[mn_config.wBaseOffset + w_offset]; + if (mn_config.blockDimM == 1) { + weight_gm_local.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } + return weight_gm_local; + } + + __aicore__ inline uint64_t SetWOffset(uint32_t tail_n, uint32_t k) { + uint64_t w_offset = 0; + if constexpr (formatWeight == CubeFormat::NZ && transposeW) { + w_offset = tail_n * (UB_BLOCK_UNIT_SIZE / sizeof(T)); // 32: quant is 32, float16 is 16 + } else if constexpr (formatWeight == CubeFormat::NZ) { + w_offset = tail_n * AlignUp(k, 16); // 16: nz format last two dim size + } else if constexpr (transposeW) { + w_offset = tail_n * k; + } else { + w_offset = tail_n; + } + return w_offset; + } +}; + diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index ae60ce05..7f3c9c18 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -597,6 +597,38 @@ void transpose_kv_cache_by_block( } +// It is expected that further improvements will be made after it is incorporated into CANN on June 30th. 
+std::vector<at::Tensor> moe_grouped_matmul(
+    at::Tensor x,
+    at::Tensor weight,
+    const at::Tensor& group_list,
+    int64_t split_item,
+    int64_t group_type,
+    int64_t group_list_type
+)
+{
+    bool transpose_weight = false;
+    bool weight_nz = true;
+
+    at::TensorList x_list = at::TensorList(x);
+    at::TensorList weight_list = at::TensorList(weight);
+    std::vector<at::Tensor> y;
+    c10::TensorOptions options = x_list[0].options().dtype(x[0].scalar_type());
+    auto m = x_list[0].sizes()[0];
+    auto n = weight_list[0].sizes()[1];
+    if (!transpose_weight) {
+        n = weight_list[0].sizes()[2];
+    }
+    at::Tensor y_0 = at::empty(at::IntArrayRef{m, n}, options);
+    y.emplace_back(y_0);
+    at::TensorList result = at::TensorList(y);
+
+    EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
+                 x_list, weight_list, group_list, transpose_weight, result);
+
+    return y;
+}
+
 } // namespace vllm_ascend
 
 TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -779,4 +811,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
    );
    ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
+    ops.def(
+        "moe_grouped_matmul("
+        "Tensor x,"
+        "Tensor weight,"
+        "Tensor group_list,"
+        "int split_item,"
+        "int group_type,"
+        "int group_list_type)"
+
+        "-> Tensor[]"
+    );
+    ops.impl("moe_grouped_matmul", torch::kPrivateUse1, &vllm_ascend::moe_grouped_matmul);
 }
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index 25c2cf85..76104616 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -457,6 +457,35 @@ void transpose_kv_cache_by_block_meta(
 {
    return;
 }
+
+std::vector<at::Tensor> moe_grouped_matmul_meta(
+    at::Tensor x,
+    at::Tensor weight,
+    const at::Tensor& group_list,
+    int64_t split_item,
+    int64_t group_type,
+    int64_t group_list_type
+)
+{
+    bool transpose_weight = false;
+    bool weight_nz = true;
+
+    at::TensorList x_list = at::TensorList(x);
+    at::TensorList weight_list = at::TensorList(weight);
+    std::vector<at::Tensor> y;
+    c10::TensorOptions options = x[0].options().dtype(x[0].scalar_type());
+    auto m = x[0].sizes()[0];
+    auto n = weight[0].sizes()[1];
+    if (!transpose_weight) {
+        n = weight[0].sizes()[2];
+    }
+    at::Tensor y_0 = at::zeros(at::IntArrayRef{m, n}, options);
+    y.emplace_back(y_0);
+    at::TensorList result = at::TensorList(y);
+
+    return y;
+}
+
 } // namespace meta
 } // namespace vllm_ascend
@@ -498,5 +527,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
    // transpose_kv_cache_by_block
    ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
+    // moe_grouped_matmul
+    ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
 }
 }
diff --git a/setup.py b/setup.py
index acac0958..d7a8f68c 100644
--- a/setup.py
+++ b/setup.py
@@ -405,8 +405,10 @@ class cmake_build_ext(build_ext):
                print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}")
 
    def run(self):
-        # First, ensure ACLNN custom-ops is built and installed.
-        self.run_command("build_aclnn")
+        if envs.COMPILE_CUSTOM_KERNELS:
+            # First, ensure ACLNN custom-ops is built and installed.
+            self.run_command("build_aclnn")
+
        # Then, run the standard build_ext command to compile the extensions
        super().run()
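For reference, a minimal Python-side sketch of how the new binding could be called once the extension is built. It assumes the op registers under torch.ops._C_ascend (following TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)), that the weight has been converted to the FRACTAL_NZ layout the WeightNz path expects, and that group_list uses the key-value layout of (group index, row count) pairs required by the tiling. Shapes, expert counts, and the npu_format_cast constant are illustrative assumptions, not values taken from this patch.

    import torch
    import torch_npu   # Ascend PyTorch adapter, assumed installed
    import vllm_ascend  # assumed to load the compiled _C_ascend extension

    num_experts, k, n = 8, 1024, 2048
    tokens_per_expert = 64
    m = num_experts * tokens_per_expert

    # x: [m, k] activations; weight: [num_experts, k, n] (transpose_weight defaults to false).
    # Only bf16/fp16 are supported by the tiling.
    x = torch.randn(m, k, dtype=torch.bfloat16, device="npu")
    weight = torch.randn(num_experts, k, n, dtype=torch.bfloat16, device="npu")
    # The binding routes through the WeightNz path, so cast the weight to FRACTAL_NZ first
    # (29 = ACL_FORMAT_FRACTAL_NZ; format constant is an assumption).
    weight = torch_npu.npu_format_cast(weight, 29)

    # Key-value group list: one (expert index, row count) pair per expert, shape [num_experts, 2].
    group_list = torch.tensor([[e, tokens_per_expert] for e in range(num_experts)],
                              dtype=torch.int64, device="npu")

    # split_item / group_type / group_list_type are part of the schema but are not read by the
    # current C++ binding, so placeholder values are passed here.
    y = torch.ops._C_ascend.moe_grouped_matmul(x, weight, group_list, 0, 0, 0)[0]  # [m, n]

The matching moe_grouped_matmul_meta registration lets the same op run on the Meta backend for shape inference and tracing without touching the NPU.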