[MOE] commit GMM custom operator (#7010)

### What this PR does / why we need it?
Optimizes the GMM custom operator for small-batch scenarios.

### How was this patch tested?
Submit the GMM custom operator for subsequent integration into the MOE
process.


- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
This commit is contained in:
chenxi-hh
2026-03-09 09:56:31 +08:00
committed by GitHub
parent 01d3515dcf
commit 737dfcf638
16 changed files with 1214 additions and 3 deletions

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Compile options for the MoeGroupedMatmulCustom kernel sources.
# NOTE(review): -Werror is hardcoded unconditionally; consider guarding it
# behind a project option so downstream toolchains can opt out.
add_ops_compile_options(
  OP_NAME MoeGroupedMatmulCustom
  OPTIONS --cce-auto-sync=off
          -Wno-deprecated-declarations
          -Werror
)

# The optiling, opsproto and op_host_aclnnInner targets all consume the same
# CANN include directories; keep the list in one place so the three copies
# cannot drift apart.
set(moe_gmm_cann_include_dirs
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${ASCEND_CANN_PACKAGE_PATH}/include
  ${ASCEND_CANN_PACKAGE_PATH}/include/external
  ${ASCEND_CANN_PACKAGE_PATH}/include/experiment
  ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
  ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef
  ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
  ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)

# Host-side tiling implementation.
target_sources(optiling PRIVATE
  moe_grouped_matmul_cpu.cpp
)
target_include_directories(optiling PRIVATE ${moe_gmm_cann_include_dirs})

# Shape-inference implementation.
target_sources(opsproto PRIVATE
  moe_grouped_matmul_infershape.cpp
)
target_include_directories(opsproto PRIVATE ${moe_gmm_cann_include_dirs})

# Operator definition (IR registration).
target_sources(op_host_aclnnInner PRIVATE
  moe_grouped_matmul_def.cpp
)
target_include_directories(op_host_aclnnInner PRIVATE ${moe_gmm_cann_include_dirs})

# aclnn L0/L2 API implementation.
target_sources(opapi PRIVATE
  moe_grouped_matmul_l0.cpp
  aclnn_moe_grouped_matmul.cpp
)

# Public headers; OPTIONAL so install does not fail if a header is absent.
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul.h"
        DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul_weight_nz.h"
        DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL)

View File

@@ -0,0 +1,248 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_moe_grouped_matmul.h"
#include "aclnn_moe_grouped_matmul_weight_nz.h"
#include <dlfcn.h>
#include <new>
#include "aclnn_kernels/transdata.h"
#include "moe_grouped_matmul_l0.h"
#include "aclnn_kernels/contiguous.h"
#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/make_op_executor.h"
using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
namespace {
// Element-count alignment requirements for FRACTAL_NZ weight layouts.
static constexpr size_t ALIGN_NZ_4BIT_N = 64UL;   // 4-bit weights, n dim (unused in this file)
static constexpr size_t ALIGN_NZ_4BIT_K = 64UL;   // 4-bit weights, k dim (unused in this file)
static constexpr size_t ALIGN_NZ_INT8_N = 32UL;   // int8 weights, n dim (unused in this file)
// BF16/FP16 NZ alignment; applied to both k and n in TransWeightToNzCheckAlign.
static constexpr size_t ALIGN_NZ_K = 16UL;
static constexpr size_t DIMS_THREE_FOR_GMM = 3UL;
// Offsets counted back from the end of a shape's dim list.
static constexpr size_t LAST_FIRST_DIM_INDEX = 1;
static constexpr size_t LAST_SECOND_DIM_INDEX = 2;
static constexpr size_t LAST_THIRD_DIM_INDEX = 3;
// Selects how the weight tensors are interpreted: plain ND or FRACTAL_NZ.
enum class GMMWeightVersion : uint32_t {
    WeightNd = 1U,
    WeightNz = 2U
};
// Bundle of all inputs/outputs threaded through the L2 implementation.
struct MoeGroupedMatmulParams {
    const aclTensorList *x = nullptr;           // input activations
    const aclTensorList *weight = nullptr;      // weights (ND or NZ, see weightVersion)
    const aclTensor *groupTensor = nullptr;     // group boundary tensor
    bool transposeWeight = false;               // weight's last two dims swapped
    bool isSingleWeight = false;
    GMMWeightVersion weightVersion = GMMWeightVersion::WeightNd;
    const aclTensorList *y = nullptr;           // output tensors
    DataType xDtype = DataType::DT_BF16;        // dtype of the first non-null x tensor
};
}
namespace {
// Validates that x, weight and y are non-null, non-empty tensor lists.
// CHECK_COND appears to early-return the given error code on the first
// violated condition, so the order below fixes which error is reported first.
static aclnnStatus CheckNotNull(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *y) {
    CHECK_COND(x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "x must not be nullptr.");
    CHECK_COND(x->Size() != 0, ACLNN_ERR_PARAM_INVALID, "x must not be empty tensorlist.");
    CHECK_COND(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight must not be nullptr.");
    CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "weight must not be empty tensorlist.");
    CHECK_COND(y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "y must not be nullptr.");
    CHECK_COND(y->Size() != 0, ACLNN_ERR_PARAM_INVALID, "y must not be empty tensorlist.");
    return ACLNN_SUCCESS;
}
// Checks that a FRACTAL_NZ weight's (k, n) dims satisfy the NZ fractal
// alignment: for BF16/FP16 both k and n must be multiples of 16.
// NOTE(review): for any other dtype kAlign/nAlign stay false and the check
// always fails — confirm this dtype restriction is intended.
static aclnnStatus TransWeightToNzCheckAlign(MoeGroupedMatmulParams &gmmParams, const aclTensor *weight)
{
    size_t viewDimNum = weight->GetViewShape().GetDimNum();
    // When the weight is transposed the last two view dims are (n, k) instead
    // of (k, n), so k/n are read from swapped positions.
    uint64_t k = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - 1) :
                 weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX);
    uint64_t n = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX) :
                 weight->GetViewShape().GetDim(viewDimNum - 1);
    bool kAlign = false;
    bool nAlign = false;
    if (weight->GetDataType() == DataType::DT_BF16 || weight->GetDataType() == DataType::DT_FLOAT16) {
        kAlign = k % ALIGN_NZ_K == 0;
        nAlign = n % ALIGN_NZ_K == 0;
    }
    // Bug fix: error text had typos ("weight'shape", stray ')' and a missing
    // space at a string-literal seam); corrected wording below.
    CHECK_COND(kAlign && nAlign, ACLNN_ERR_PARAM_INVALID,
        "When weight format is FRACTAL_NZ, weight's shape(k[%lu], n[%lu]) should be divisible by the "
        "following shape: BF16/FP16[16, 16]. If the weight is transposed, "
        "the k/n need to be reversed.",
        k, n);
    return ACLNN_SUCCESS;
}
// Validates NZ alignment for every weight tensor stored in a FRACTAL_NZ
// storage format. Stops at the first non-NZ weight (the rest of the list is
// then treated as ND).
static aclnnStatus TransWeightToNz(MoeGroupedMatmulParams &gmmParams, aclOpExecutor *executor) {
    const aclTensorList *&weights = gmmParams.weight;
    const aclTensorList *&x = gmmParams.x;
    CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "The first tensor of x is nullptr!");
    size_t wLength = weights->Size();
    for (size_t i(0); i < wLength; ++i) {
        const aclTensor* weight = (*weights)[i];
        if (weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ &&
            weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_16 &&
            weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_32) {
            break;
        }
        // Bug fix: the alignment-check status was previously discarded, so a
        // misaligned NZ weight silently passed validation. Propagate it now.
        aclnnStatus alignRet = TransWeightToNzCheckAlign(gmmParams, weight);
        CHECK_COND(alignRet == ACLNN_SUCCESS, alignRet, "NZ weight[%zu] failed the alignment check.", i);
    }
    return ACLNN_SUCCESS;
}
// Wraps `input` in a view tensor whose storage format is FRACTAL_NZ while
// keeping the original view shape/format, so the L0 op reads it in NZ layout.
// Returns nullptr if the executor fails to create the view.
static const aclTensor *SetTensorToNZFormat(const aclTensor *input, op::Shape &shape, aclOpExecutor *executor) {
    auto formatTensor = executor->CreateView(input, shape, input->GetViewOffset());
    // Bug fix: CreateView can fail; the original dereferenced the result
    // unconditionally and would crash on a null view.
    if (formatTensor == nullptr) {
        OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "CreateView failed when setting NZ format.");
        return nullptr;
    }
    formatTensor->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
    formatTensor->SetOriginalFormat(input->GetViewFormat());
    formatTensor->SetViewShape(input->GetViewShape());
    return formatTensor;
}
// Replaces `tensors` in place with a list in which every tensor is contiguous.
static aclnnStatus DataContiguous(const aclTensorList *&tensors, aclOpExecutor *executor) {
    std::vector<const aclTensor *> tensorsVec;
    tensorsVec.reserve(tensors->Size());
    for (size_t i = 0; i < tensors->Size(); ++i) {
        // Loop-scoped local (the original hoisted it for no benefit).
        const aclTensor *contiguousTensor = l0op::Contiguous((*tensors)[i], executor);
        CHECK_RET(contiguousTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
        tensorsVec.push_back(contiguousTensor);
    }
    tensors = executor->AllocTensorList(tensorsVec.data(), tensorsVec.size());
    // Bug fix: a failed AllocTensorList previously went unnoticed and left
    // the caller holding a null tensor list.
    CHECK_RET(tensors != nullptr, ACLNN_ERR_INNER_NULLPTR);
    return ACLNN_SUCCESS;
}
// Makes x, weight and groupTensor contiguous, updating `params` in place.
// Cleanup: the original also read xDtype/weightDtype into locals that were
// never used; they are removed here.
static aclnnStatus ParamsDataContiguous(MoeGroupedMatmulParams &params, aclOpExecutor *executorPtr) {
    CHECK_COND(DataContiguous(params.x, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "Contiguous x failed."); // make x contiguous
    CHECK_COND(DataContiguous(params.weight, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "Contiguous weight failed."); // make w contiguous
    params.groupTensor = l0op::Contiguous(params.groupTensor, executorPtr);
    CHECK_COND(params.groupTensor != nullptr, ACLNN_ERR_PARAM_INVALID,
               "Contiguous groupTensor failed.");
    return ACLNN_SUCCESS;
}
// Builds the executor graph for MoeGroupedMatmul: makes the inputs
// contiguous, re-tags NZ weights, launches the L0 op and view-copies the
// results into `params.y`. On success fills `workspaceSize` and releases the
// executor to the caller.
static aclnnStatus GetGMMResultByL0Api(MoeGroupedMatmulParams &params, uint64_t *workspaceSize, aclOpExecutor **executor) {
    auto uniqueExecutor = CREATE_EXECUTOR(); // fixed written style, create OpExecutor
    aclOpExecutor *executorPtr = uniqueExecutor.get();
    CHECK_RET(executorPtr != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    CHECK_COND(ParamsDataContiguous(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "ParamsDataContiguous failed.");
    if (params.weightVersion == GMMWeightVersion::WeightNz) {
        // Re-wrap each weight as a FRACTAL_NZ view so the L0 op reads it in NZ layout.
        std::vector<const aclTensor *> tensorsVec;
        for (size_t i = 0; i < params.weight->Size(); ++i) {
            const aclTensor *tensor = (*params.weight)[i];
            op::Shape weightNzShape = tensor->GetViewShape();
            tensor = SetTensorToNZFormat(tensor, weightNzShape, executorPtr);
            // Robustness: do not push a null view into the new list.
            CHECK_RET(tensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
            tensorsVec.push_back(tensor);
        }
        params.weight = executorPtr->AllocTensorList(tensorsVec.data(), tensorsVec.size());
        CHECK_RET(params.weight != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    CHECK_COND(TransWeightToNz(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "TransWeightToNz failed.");
    // Invoke l0 operator MoeGroupedMatmul for calculation.
    auto result = l0op::MoeGroupedMatmul(params.x, params.weight,
                                         params.groupTensor, params.transposeWeight,
                                         (*params.y)[0]->GetViewShape(), params.y->Size(),
                                         (*params.y)[0]->GetDataType(), executorPtr);
    CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR);
    // Copy every computed tensor back into the caller-provided output list.
    for (size_t i(0); i < params.y->Size(); ++i) {
        auto viewCopyResult = l0op::ViewCopy((*result)[i], (*params.y)[i], executorPtr);
        CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    // Standard syntax, get the size of workspace needed during computation.
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}
// Shared implementation behind the ND and NZ workspace-size entry points:
// resolves the x dtype from the first non-null tensor in the list and
// forwards everything to GetGMMResultByL0Api.
static aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(const aclTensorList *x, const aclTensorList *weight,
    const aclTensor *groupList, bool transposeWeight, GMMWeightVersion weightVersion,
    const aclTensorList *y, uint64_t *workspaceSize, aclOpExecutor **executor) {
    DataType xDtype = DataType::DT_UNDEFINED;
    for (size_t idx = 0; idx < x->Size(); ++idx) {
        const aclTensor *candidate = (*x)[idx];
        if (candidate != nullptr) {
            xDtype = candidate->GetDataType();
            break;
        }
    }
    MoeGroupedMatmulParams moeGmmParams{x, weight, groupList, transposeWeight, true, weightVersion, y, xDtype};
    return GetGMMResultByL0Api(moeGmmParams, workspaceSize, executor);
}
}
// Phase 1 for FRACTAL_NZ weights: validates the tensor lists, records DFX
// info and delegates to the shared implementation with WeightNz selected.
aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
    const aclTensor *groupList, bool transposeWeight, aclTensorList *out,
    uint64_t *workspaceSize, aclOpExecutor **executor) {
    CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
               "one of required inputs is nullptr.");
    // Standard syntax, Check parameters.
    L2_DFX_PHASE_1(aclnnMoeGroupedMatmulWeightNz,
                   DFX_IN(x, weight, groupList),
                   DFX_OUT(out));
    return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNz, out, workspaceSize, executor);
}
// Phase 1 for ND weights: validates the tensor lists, records DFX info and
// delegates to the shared implementation with WeightNd selected.
aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
    const aclTensor *groupList, bool transposeWeight, aclTensorList *out,
    uint64_t *workspaceSize, aclOpExecutor **executor) {
    // CheckNotNull already rejects null/empty x, weight and out; the extra
    // weight->Size() check the original repeated here was redundant and made
    // this entry point inconsistent with the NZ variant, so it is removed.
    CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
               "one of required inputs is nullptr.");
    // Standard syntax, Check parameters.
    L2_DFX_PHASE_1(aclnnMoeGroupedMatmul,
                   DFX_IN(x, weight, groupList),
                   DFX_OUT(out));
    return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNd, out, workspaceSize, executor);
}
// Phase 2 (ND weights): launches the executor prepared by
// aclnnMoeGroupedMatmulGetWorkspaceSize on `stream` using `workspace`.
aclnnStatus aclnnMoeGroupedMatmul(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                  aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnMoeGroupedMatmul);
    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "This is an error in GMM launch aicore");
    return ACLNN_SUCCESS;
}
// Phase 2 (NZ weights): launches the executor prepared by
// aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize on `stream` using `workspace`.
aclnnStatus aclnnMoeGroupedMatmulWeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                          aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnMoeGroupedMatmulWeightNz);
    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "This is an error in GMM launch aicore");
    return ACLNN_SUCCESS;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,30 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_MOE_GROUPED_MATMUL_H
#define OP_API_INC_MOE_GROUPED_MATMUL_H
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * Phase 1 of the two-phase aclnn call for MoeGroupedMatmul with ND weights:
 * computes the required workspace size and prepares an executor.
 *
 * @param x               dynamic list of input tensors
 * @param weight          dynamic list of ND weight tensors
 * @param groupList       group boundary tensor
 * @param transposeWeight whether the weight's last two dims are transposed
 * @param out             output tensor list, written by the phase-2 call
 * @param workspaceSize   [out] required workspace size in bytes
 * @param executor        [out] executor to pass to aclnnMoeGroupedMatmul
 */
__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize(
    const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList,
    bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor);
/**
 * Phase 2: launches the executor prepared by the phase-1 call on `stream`,
 * using the workspace allocated by the caller.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
    aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif  // OP_API_INC_MOE_GROUPED_MATMUL_H

View File

@@ -0,0 +1,30 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H
#define OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H
#include "aclnn/aclnn_base.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * Phase 1 of the two-phase aclnn call for MoeGroupedMatmul with FRACTAL_NZ
 * weights: computes the required workspace size and prepares an executor.
 *
 * @param x               dynamic list of input tensors
 * @param weight          dynamic list of FRACTAL_NZ weight tensors
 * @param groupList       group boundary tensor
 * @param transposeWeight whether the weight's last two dims are transposed
 * @param out             output tensor list, written by the phase-2 call
 * @param workspaceSize   [out] required workspace size in bytes
 * @param executor        [out] executor to pass to aclnnMoeGroupedMatmulWeightNz
 */
__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize(
    const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList,
    bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor);
/**
 * Phase 2: launches the executor prepared by the phase-1 call on `stream`,
 * using the workspace allocated by the caller.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNz(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
    aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif  // OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H

View File

@@ -0,0 +1,336 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "moe_grouped_matmul_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
// Minimal logging shims forwarding to printf (nodeName is ignored).
// Bug fix: the original expansion was two bare statements
// (`printf(fmt, ...); printf("\n")`), which silently breaks when the macro
// is the single statement of an unbraced if/else; wrap in do { } while (0).
#define OP_LOGD(nodeName, fmt, ...) do { printf(fmt, ##__VA_ARGS__); printf("\n"); } while (0)
#define OP_LOGE(nodeName, fmt, ...) do { printf(fmt, ##__VA_ARGS__); printf("\n"); } while (0)
// Indices of the operator's inputs.
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t GROUPLIST_INDEX = 2;
namespace optiling {
// L1 staging-buffer sizes in bytes.
constexpr uint64_t BEST_L1_PARTA = 256UL * 1024UL;
constexpr uint64_t BEST_L1_PARTB = 128UL * 1024UL;
constexpr uint64_t L1_PARTA_SIZE = 256UL * 1024UL;
// Preferred base block size along n.
constexpr int32_t BEST_BASEN = 256;
// Double-buffer split factors for L0A/L0B and the L1 stepKa/stepKb staging.
constexpr uint64_t DOUBLE_BUFFER_L0A_L0B = 2;
constexpr uint64_t DOUBLE_BUFFER_STEPKA_STEPKB = 2;
constexpr uint32_t FP32_DATATYPE_SIZE = 4;  // bytes per fp32 accumulator element
constexpr int32_t MAX_BASEM = 256;          // upper bound for baseM
// Rounds `a` down (default) or up to the nearest multiple of 16.
static inline uint32_t SixteenAlign(uint32_t a, bool up = false) {
    const uint32_t bias = up ? 15U : 0U;  // adding 15 before masking rounds up
    return (a + bias) & ~15U;
}
// Rounds `a` down (default) or up to the nearest multiple of 16
// (64-bit signed overload).
static inline int64_t SixteenAlign(int64_t a, bool up = false) {
    const int64_t bias = up ? 15 : 0;  // adding 15 before masking rounds up
    return (a + bias) & ~static_cast<int64_t>(15);
}
// Computes the host-side tiling for the MoeGroupedMatmul kernel: reads
// shapes/attrs from the TilingContext, derives the base matmul block sizes,
// configures the AscendC matmul tiling and serializes the result into the
// context's raw tiling data.
class TilingMoeGroupedMatmulFunc {
public:
    explicit TilingMoeGroupedMatmulFunc(gert::TilingContext* tiling_context)
        : tiling_context_(tiling_context) {}
    // Reads and validates inputs; caches m/n/k and platform info.
    ge::graphStatus Init();
    // Runs the full tiling pipeline and writes the result into the context.
    ge::graphStatus RunKernelTiling();
private:
    MoeGroupedMatmulTilingData tiling_data_;
    gert::TilingContext* tiling_context_ = nullptr;
    void SetTilingKey();
    void FillTilingData();
    ge::graphStatus CalMMTiling();
    ge::graphStatus GMMSetMMTiling();
    ge::graphStatus CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb);
    ge::graphStatus DynamicTIlingSingleN();  // NOTE(review): declared but not defined in this file
    void InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo);
    void GMMGetPlatformInfo();
    int64_t m_ = 0L;                  // rows of x
    int64_t n_ = 0L;                  // output columns, taken from the weight shape
    int64_t k_ = 0L;                  // reduction dimension
    bool transpose_weight = false;    // attr 0: weight's last two dims swapped
    bool weight_nz = false;           // weight stored as FRACTAL_NZ
    uint32_t single_m_ = 0;           // per-task tile size along m
    uint32_t single_n_ = 0;           // per-task tile size along n
    int32_t baseM_ = 0;
    int32_t baseN_ = 0;
    int32_t baseK_ = 0;
    int32_t nz_factor_ = 1;
    uint32_t group_num_ = 0;          // number of groups (weight dim0)
    uint32_t mmDataTypeSize_ = 0;     // bytes per matmul input element
    uint32_t core_num_ = 0;           // cores actually used (<= aic_num)
    size_t sync_workspace_size_ = 0;  // system workspace required by the lib API
    ge::DataType x_dtype;
    // On-chip memory sizes in bytes, filled by GMMGetPlatformInfo().
    uint64_t l1_size, l0a_size, l0b_size, l0c_size, ub_size;
    uint32_t aic_num, aiv_num;        // cube / vector core counts
    // SocVersion soc_version;
};
// Computes how many baseK-sized steps of A (stepKa) and B (stepKb) fit into
// the two L1 partitions under double buffering, then rounds the larger step
// count down to a multiple of the smaller so the pipeline stays in lockstep.
ge::graphStatus TilingMoeGroupedMatmulFunc::CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb) {
    uint64_t available_l1_size = l1_size;
    if (available_l1_size < L1_PARTA_SIZE) {
        OP_LOGE(tiling_context_->GetNodeName(), "available_l1_size is less than 256k.");
        return ge::GRAPH_FAILED;
    }
    // according to double buffer, recompute the params used for data movement from GM to L1
    uint64_t l1_a_size = baseM_ > baseN_ ? L1_PARTA_SIZE : available_l1_size - L1_PARTA_SIZE;
    uint64_t l1_b_size = available_l1_size - l1_a_size;
    uint64_t a_tile_bytes = static_cast<uint64_t>(baseM_) * baseK_ * mmDataTypeSize_;
    uint64_t b_tile_bytes = static_cast<uint64_t>(baseN_) * baseK_ * mmDataTypeSize_;
    // Robustness: CalMMTiling normally guarantees non-zero base sizes, but
    // fail cleanly here instead of dividing by zero.
    if (a_tile_bytes == 0 || b_tile_bytes == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "base tile size cannot be 0.");
        return ge::GRAPH_FAILED;
    }
    // Consistency: use the named double-buffer constant instead of a bare 2UL.
    mm_step_ka = (l1_a_size / DOUBLE_BUFFER_STEPKA_STEPKB) / a_tile_bytes;
    mm_step_kb = (l1_b_size / DOUBLE_BUFFER_STEPKA_STEPKB) / b_tile_bytes;
    if (mm_step_ka == 0 || mm_step_kb == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "stepka or stepkb cannot be 0.");
        return ge::GRAPH_FAILED;
    }
    // Truncate the larger step count to a multiple of the smaller one.
    if (mm_step_ka > mm_step_kb) {
        mm_step_ka = mm_step_ka / mm_step_kb * mm_step_kb;
    } else if (mm_step_ka < mm_step_kb) {
        mm_step_kb = mm_step_kb / mm_step_ka * mm_step_ka;
    }
    return ge::GRAPH_SUCCESS;
}
// Queries per-core on-chip memory sizes and core counts from the platform
// and caches them in member fields for the tiling computations.
void TilingMoeGroupedMatmulFunc::GMMGetPlatformInfo() {
    auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo());
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0a_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0b_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0c_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size);
    aic_num = platform_info.GetCoreNumAic();
    aiv_num = platform_info.GetCoreNumAiv();
}
// Copies the cached on-chip memory sizes and the SoC version into the
// matmul-tiling PlatformInfo consumed by MultiCoreMatmulTiling.
void TilingMoeGroupedMatmulFunc::InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo) {
    platformInfo.socVersion =
        platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo()).GetSocVersion();
    platformInfo.l0ASize = l0a_size;
    platformInfo.l0BSize = l0b_size;
    platformInfo.l0CSize = l0c_size;
    platformInfo.l1Size = l1_size;
    platformInfo.ubSize = ub_size;
}
// Configures the AscendC multi-core matmul tiling for one (m, baseN, k)
// slice, then overwrites the auto-derived L1 staging parameters with the
// precomputed baseM/baseN/baseK/stepKa/stepKb values.
ge::graphStatus TilingMoeGroupedMatmulFunc::GMMSetMMTiling() {
    matmul_tiling::DataType matmul_dtype = static_cast<matmul_tiling::DataType>(x_dtype);
    matmul_tiling::PlatformInfo platformInfo;
    InitPlatformInfo(platformInfo);
    matmul_tiling::MultiCoreMatmulTiling mm(platformInfo);
    mm.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, false);
    if (weight_nz) {
        mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_dtype, transpose_weight);
    } else {
        mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, transpose_weight);
    }
    mm.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype);
    mm.SetOrgShape(m_, n_, k_);
    mm.SetShape(m_, baseN_, k_);
    mm.SetFixSplit(baseM_, baseN_, baseK_);
    mm.SetBufferSpace(l1_size, l0c_size, ub_size);
    if (mm.GetTiling(tiling_data_.mm_tiling) == -1) {
        OP_LOGE(tiling_context_->GetNodeName(), "matmul getTiling failed.");
        return ge::GRAPH_FAILED;
    }
    uint32_t mm_step_ka = 1;
    uint32_t mm_step_kb = 1;
    auto ret = CalcStepKaKb(mm_step_ka, mm_step_kb);
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(tiling_context_->GetNodeName(), "matmul calc stepka or stepkb failed.");
        return ge::GRAPH_FAILED;
    }
    constexpr uint32_t step_m = 1; // 1: step_m set fixed value 1
    constexpr uint32_t step_n = 1; // 1: step_n set fixed value 1
    uint32_t mm_depth_a1 = mm_step_ka * DOUBLE_BUFFER_STEPKA_STEPKB * step_m;
    uint32_t mm_depth_b1 = mm_step_kb * DOUBLE_BUFFER_STEPKA_STEPKB * step_n;
    tiling_data_.mm_tiling.set_shareMode(0);
    tiling_data_.mm_tiling.set_dbL0C(1); // disable double buffer for L0C
    tiling_data_.mm_tiling.set_baseM(baseM_); // set precomputed baseM
    tiling_data_.mm_tiling.set_baseN(baseN_); // set precomputed baseN
    tiling_data_.mm_tiling.set_baseK(baseK_); // set precomputed baseK
    tiling_data_.mm_tiling.set_stepKa(mm_step_ka); // set precomputed mmStepKa
    tiling_data_.mm_tiling.set_depthA1(mm_depth_a1); // set precomputed mmDepthA1
    tiling_data_.mm_tiling.set_stepKb(mm_step_kb); // set precomputed mmStepKb
    tiling_data_.mm_tiling.set_depthB1(mm_depth_b1); // set precomputed mmDepthB1
    tiling_data_.mm_tiling.set_stepM(step_m); // set precomputed stepM
    tiling_data_.mm_tiling.set_stepN(step_n); // set precomputed stepN
    // Bug fix: this log referenced an undefined identifier `context`; it must
    // go through the stored tiling_context_. Dead commented-out code removed.
    OP_LOGD(tiling_context_->GetNodeName(), "GMM_tiling: baseM is %d, baseK is %d, baseN is %d, transpose_weight is %d, weight_nz is %d", baseM_, baseK_, baseN_, transpose_weight, weight_nz);
    return ge::GRAPH_SUCCESS;
}
// Derives the base matmul block sizes (baseM/baseN/baseK) from the L0 buffer
// capacities and the input dtype.
ge::graphStatus TilingMoeGroupedMatmulFunc::CalMMTiling() {
    baseN_ = BEST_BASEN;
    if (x_dtype == ge::DT_BF16 || x_dtype == ge::DT_FLOAT16) {
        mmDataTypeSize_ = 2;  // 2 bytes per bf16/fp16 element
    } else {
        OP_LOGE(tiling_context_->GetNodeName(), "only support bf16 or fp16.");
        return ge::GRAPH_FAILED;
    }
    // baseK fills half of L0B (double buffered) at the chosen baseN, rounded
    // down to a multiple of 16.
    baseK_ = static_cast<int32_t>((l0b_size / DOUBLE_BUFFER_L0A_L0B) / (static_cast<uint32_t>(baseN_) * mmDataTypeSize_));
    baseK_ = static_cast<int32_t>(SixteenAlign(static_cast<int64_t>(baseK_)));
    // Bug fix: the original divided by baseK_ below BEFORE its zero-check at
    // the end of the function; validate it first.
    if (baseK_ == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "baseK_ cannot be 0.");
        return ge::GRAPH_FAILED;
    }
    // baseM is bounded by both the L0C capacity (fp32 accumulators) and half
    // of L0A (double buffered), then clamped to m_ and MAX_BASEM.
    uint32_t max_base_m = static_cast<uint32_t>(l0c_size /
        (static_cast<uint32_t>(baseN_) * FP32_DATATYPE_SIZE));
    baseM_ = std::min<uint32_t>((l0a_size / DOUBLE_BUFFER_L0A_L0B) /
        (static_cast<uint32_t>(baseK_) * mmDataTypeSize_), max_base_m);
    baseM_ = baseM_ > m_ ? SixteenAlign(m_, true) : SixteenAlign(static_cast<uint32_t>(baseM_));
    if (baseM_ > MAX_BASEM) {
        baseM_ = MAX_BASEM;
    }
    // Bug fix: the original message said "baseM_ or baseN_" but the condition
    // checked baseM_/baseK_; baseK_ is already validated above.
    if (baseM_ == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "baseM_ cannot be 0.");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
// Reads shapes/attrs from the tiling context, validates the single-tensor
// restriction and caches m/n/k plus the per-core work partitioning.
ge::graphStatus TilingMoeGroupedMatmulFunc::Init() {
    GMMGetPlatformInfo();
    // Only a single x, a single weight and a single y tensor are supported.
    bool is_single_x = (tiling_context_->GetDynamicInputTensor(X_INDEX, 1) == nullptr);
    bool is_single_weight = (tiling_context_->GetDynamicInputTensor(WEIGHT_INDEX, 1) == nullptr);
    bool is_single_y = (tiling_context_->GetOutputShape(1) == nullptr);
    transpose_weight = static_cast<bool>(*(tiling_context_->GetAttrs()->GetAttrPointer<bool>(0)));
    if (!(is_single_x && is_single_weight && is_single_y)) {
        OP_LOGE(tiling_context_->GetNodeName(), "only support singlex and singleweight and singley.");
        return ge::GRAPH_FAILED;
    }
    auto x_shape = tiling_context_->GetDynamicInputShape(X_INDEX, 0)->GetOriginShape();
    auto weight_shape = tiling_context_->GetDynamicInputShape(WEIGHT_INDEX, 0)->GetOriginShape();
    auto group_list_shape = tiling_context_->GetInputShape(GROUPLIST_INDEX)->GetOriginShape();
    auto y_shape = tiling_context_->GetOutputShape(0)->GetOriginShape();
    auto weight_desc = tiling_context_->GetDynamicInputDesc(WEIGHT_INDEX, 0);
    auto weight_format = static_cast<ge::Format>(ge::GetPrimaryFormat(weight_desc->GetStorageFormat()));
    x_dtype = tiling_context_->GetDynamicInputDesc(X_INDEX, 0)->GetDataType();
    weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ;
    // check input shape
    if (x_shape.GetDimNum() != 2 || y_shape.GetDimNum() != 2) {
        OP_LOGE(tiling_context_->GetNodeName(), "the dimNum of input and output should be 2, but got %zu, %zu.", static_cast<size_t>(x_shape.GetDimNum()), static_cast<size_t>(y_shape.GetDimNum()));
        return ge::GRAPH_FAILED;
    }
    // NOTE(review): weight_shape's rank is not validated before GetDim(1)/(2);
    // confirm upstream guarantees a 3-D ND weight here.
    n_ = transpose_weight ? weight_shape.GetDim(1) : weight_shape.GetDim(2);
    if (group_list_shape.GetDimNum() != 2) {
        OP_LOGE(tiling_context_->GetNodeName(), "only support key-value mode of groupList, the dimNum of groupList should be 2, but got %zu.", static_cast<size_t>(group_list_shape.GetDimNum()));
        return ge::GRAPH_FAILED;
    }
    m_ = x_shape.GetDim(x_shape.GetDimNum() - 2);
    k_ = x_shape.GetDim(x_shape.GetDimNum() - 1);
    group_num_ = weight_shape.GetDim(0);
    // Bug fix: the original compared weight_shape.GetDim(0) against group_num_
    // right after assigning one from the other, so the check could never fire;
    // compare against the groupList dim as the error message intends, and
    // actually fail on mismatch. (Unused locals weight_dim1/weight_dim2 removed.)
    if (group_list_shape.GetDim(0) != weight_shape.GetDim(0)) {
        OP_LOGE(tiling_context_->GetNodeName(), "the dim0 of input weight should be equal to input groupList, but got %zu, %zu.", static_cast<size_t>(weight_shape.GetDim(0)), static_cast<size_t>(group_list_shape.GetDim(0)));
        return ge::GRAPH_FAILED;
    }
    single_m_ = 128;
    single_n_ = 256;
    core_num_ = aic_num;
    // Partition work into single_n_-wide column strips per row of x.
    auto n_task_num = (n_ + single_n_ - 1) / single_n_;
    auto task_num = m_ * n_task_num;
    if (task_num < core_num_) {
        core_num_ = task_num;
    }
    auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo());
    sync_workspace_size_ = static_cast<size_t>(platform_info.GetLibApiWorkSpaceSize());
    return ge::GRAPH_SUCCESS;
}
// Encodes the tiling key: base key 10, plus 1 when the weight is transposed.
void TilingMoeGroupedMatmulFunc::SetTilingKey() {
    const uint64_t base_key = 10UL;
    tiling_context_->SetTilingKey(transpose_weight ? base_key + 1UL : base_key);
}
// Copies the cached shape and partitioning values into the serialized
// tiling-data structure consumed by the kernel.
void TilingMoeGroupedMatmulFunc::FillTilingData() {
    tiling_data_.set_m(static_cast<uint32_t>(m_));
    tiling_data_.set_n(static_cast<uint32_t>(n_));
    tiling_data_.set_k(static_cast<uint32_t>(k_));
    tiling_data_.set_single_m(single_m_);
    tiling_data_.set_single_n(single_n_);
    tiling_data_.set_group_num(group_num_);
    tiling_data_.set_core_num(core_num_);
}
// Runs the full tiling pipeline: base block sizes, matmul tiling, tiling key,
// tiling-data serialization and workspace/block-dim registration.
ge::graphStatus TilingMoeGroupedMatmulFunc::RunKernelTiling() {
    auto ret = CalMMTiling();
    if (ret != ge::GRAPH_SUCCESS) {
        // Bug fix: these logs referenced an undefined identifier `context`
        // (a compile error); use the stored tiling_context_.
        OP_LOGE(tiling_context_->GetNodeName(), "cal mmtiling failed.");
        return ge::GRAPH_FAILED;
    }
    ret = GMMSetMMTiling();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(tiling_context_->GetNodeName(), "gmm set mmtiling failed.");
        return ge::GRAPH_FAILED;
    }
    SetTilingKey();
    FillTilingData();
    // Only the system (lib API) workspace is needed; no user workspace.
    size_t userWorkspaceSize = 0;
    size_t *currentWorkspace = tiling_context_->GetWorkspaceSizes(1);
    currentWorkspace[0] = userWorkspaceSize + sync_workspace_size_;
    tiling_data_.SaveToBuffer(tiling_context_->GetRawTilingData()->GetData(),
                              tiling_context_->GetRawTilingData()->GetCapacity());
    tiling_context_->GetRawTilingData()->SetDataSize(tiling_data_.GetDataSize());
    tiling_context_->SetBlockDim(core_num_);
    return ge::GRAPH_SUCCESS;
}
// Tiling entry point registered with the framework: builds the helper,
// initializes it from the context and runs the kernel tiling.
static ge::graphStatus TilingForMoeGroupedMatmulFunc(gert::TilingContext *context) {
    TilingMoeGroupedMatmulFunc tilingObject(context);
    if (tilingObject.Init() != ge::GRAPH_SUCCESS) {
        OP_LOGE(context->GetNodeName(), "tiling Init failed.");
        return ge::GRAPH_FAILED;
    }
    return tilingObject.RunKernelTiling();
}
// Compile info and tiling-parse hook for MoeGroupedMatmul.
// Bug fix: these were copy-paste leftovers named after
// MatmulAllreduceAddRmsnorm (with a "1" suffix to dodge a collision), which
// is misleading; rename them after the op they actually register and give
// the hook internal linkage.
struct MoeGroupedMatmulCompileInfo {};
static ge::graphStatus TilingParseForMoeGroupedMatmul(gert::TilingParseContext *context)
{
    (void)context;  // no compile info to parse
    return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(MoeGroupedMatmul)
    .Tiling(TilingForMoeGroupedMatmulFunc)
    .TilingParse<MoeGroupedMatmulCompileInfo>(TilingParseForMoeGroupedMatmul);
} // namespace optiling

View File

@@ -0,0 +1,53 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_grouped_matmul_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
#include "moe_grouped_matmul_infershape.cpp"
namespace ops {
// IR registration for MoeGroupedMatmul. Each column in the DataType/Format
// lists below is one supported combination; the last two columns pair ND
// x/y tensors with FRACTAL_NZ weights.
class MoeGroupedMatmul : public OpDef {
public:
    explicit MoeGroupedMatmul(const char *name) : OpDef(name) {
        // x: dynamic input list, ND, bf16/fp16.
        this->Input("x")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // weight: ND in the first two combinations, FRACTAL_NZ in the last two.
        this->Input("weight")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // group_list: int64/int32, ND (the tiling expects key-value mode).
        this->Input("group_list")
            .ParamType(REQUIRED)
            .DataTypeList({ge::DT_INT64, ge::DT_INT32})
            .FormatList({ge::FORMAT_ND});
        this->Output("y")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // transpose_weight: whether weight's last two dims are swapped.
        this->Attr("transpose_weight")
            .AttrType(OPTIONAL)
            .Bool(false);
        this->SetInferShape(ge::InferShape);
        this->SetInferDataType(ge::InferDataType);
        // Supported SoC configurations.
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
    }
};
OP_ADD(MoeGroupedMatmul);
} // namespace ops

View File

@@ -0,0 +1,35 @@
#include "register/op_impl_registry.h"
#include <string>
namespace ge {
// Indices of the operator's inputs (must match the op definition order).
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t GROUPLIST_INDEX = 2;
/*
 * Shape inference for MoeGroupedMatmul.
 *
 * y starts as a copy of x's shape; dim 1 is then replaced by the weight's N
 * dimension. For FRACTAL_NZ weights the logical N is the product of two
 * storage dims (dim1*dim3 when transposed, dim2*dim4 otherwise); for ND
 * weights it is dim 1 (transposed) or dim 2.
 *
 * Returns GRAPH_SUCCESS on success, GRAPH_FAILED if any required context
 * pointer is unavailable.
 */
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
    const gert::Shape *x_shape = context->GetDynamicInputShape(X_INDEX, 0);
    const gert::Shape *weight_shape = context->GetDynamicInputShape(WEIGHT_INDEX, 0);
    gert::Shape *y_shape = context->GetOutputShape(0);
    const auto *attrs = context->GetAttrs();
    auto weight_desc = context->GetDynamicInputDesc(WEIGHT_INDEX, 0);
    // Fail cleanly instead of dereferencing null framework pointers.
    if (x_shape == nullptr || weight_shape == nullptr || y_shape == nullptr ||
        attrs == nullptr || weight_desc == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const bool *transpose_ptr = attrs->GetAttrPointer<bool>(0);
    if (transpose_ptr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const bool transpose_weight = *transpose_ptr;
    *y_shape = *x_shape;
    auto weight_format = static_cast<ge::Format>(ge::GetPrimaryFormat(weight_desc->GetStorageFormat()));
    bool weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ;
    int64_t dim_n;
    if (weight_nz) {
        // NZ layout splits N across two storage dimensions.
        dim_n = transpose_weight ? (weight_shape->GetDim(1) * weight_shape->GetDim(3)) :
                                   (weight_shape->GetDim(2) * weight_shape->GetDim(4));
    } else {
        dim_n = transpose_weight ? weight_shape->GetDim(1) : weight_shape->GetDim(2);
    }
    y_shape->SetDim(1, dim_n);
    // Bug fix: the original fell off the end of a non-void function (UB).
    return ge::GRAPH_SUCCESS;
}
// Dtype inference: output y[0] simply mirrors the dtype of the first x tensor.
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) {
    context->SetOutputDataType(0, context->GetDynamicInputDataType(X_INDEX, 0));
    return ge::GRAPH_SUCCESS;
}
} // namespace ge

View File

@@ -0,0 +1,50 @@
#include "moe_grouped_matmul_l0.h"
#include "opdev/op_log.h"
#include "opdev/op_dfx.h"
#include "opdev/shape_utils.h"
#include "opdev/make_op_executor.h"
using namespace op;
namespace l0op {
OP_TYPE_REGISTER(MoeGroupedMatmul);  // register the op type name with the opdev framework
/*
 * L0 launcher for the MoeGroupedMatmul AICore kernel.
 *
 * Allocates outLength output tensors (all of shape yShape / dtype yDtype) on
 * the executor, then queues the kernel with inputs (x, weight, groupList) and
 * the transposeWeight attribute.
 *
 * Returns the allocated output tensor list, or nullptr when the input list is
 * empty/invalid or the kernel launch cannot be queued.
 */
const aclTensorList *MoeGroupedMatmul(const aclTensorList *x,
                                      const aclTensorList *weight,
                                      const aclTensor *groupList,
                                      bool transposeWeight,
                                      op::Shape yShape,
                                      size_t outLength,
                                      op::DataType yDtype,
                                      aclOpExecutor *executor) {
    L0_DFX(MoeGroupedMatmul, x, weight, groupList, transposeWeight, outLength);
    // Guard the list itself before dereferencing it, then require a first tensor.
    const aclTensor *x0 = (x != nullptr && x->Size() > 0) ? (*x)[0] : nullptr;
    if (x0 == nullptr) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "(*x)[0] is nullptr.");
        return nullptr;
    }
    // Allocate the output tensor list: outLength tensors, all yShape/yDtype.
    std::vector<const aclTensor *> tensorsVec;
    tensorsVec.reserve(outLength);
    for (size_t i = 0; i < outLength; ++i) {
        tensorsVec.emplace_back(executor->AllocTensor(yShape, yDtype));
    }
    auto out = executor->AllocTensorList(tensorsVec.data(), outLength);
    // Bug fix: removed a leftover debug printf that passed size_t/int64_t
    // values to "%d" (undefined behavior) and bypassed the OP_LOG framework;
    // its two feeder locals were dropped with it.
    auto ret = ADD_TO_LAUNCHER_LIST_AICORE(MoeGroupedMatmul,
                                           OP_INPUT(x, weight, groupList),
                                           OP_OUTPUT(out),
                                           OP_ATTR(transposeWeight));
    if (ret != ACLNN_SUCCESS) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed.");
        return nullptr;
    }
    return out;
}
} // namespace l0op

View File

@@ -0,0 +1,27 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H
#define OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H
#include "opdev/op_executor.h"
namespace l0op {
/*
 * Launches the MoeGroupedMatmul AICore kernel (L0 level).
 *
 * x               input tensor list; must contain at least one tensor.
 * weight          per-group weight tensor list.
 * groupList       group-boundary tensor (see the op IR definition).
 * transposeWeight forwarded to the kernel's transpose_weight attribute.
 * yShape          shape of each allocated output tensor.
 * outLength       number of output tensors to allocate.
 * yDtype          dtype of each output tensor.
 * executor        executor used for tensor allocation and kernel queueing.
 *
 * Returns the allocated output tensor list, or nullptr on failure.
 */
const aclTensorList *MoeGroupedMatmul(const aclTensorList *x,
                                      const aclTensorList *weight,
                                      const aclTensor *groupList,
                                      bool transposeWeight,
                                      op::Shape yShape,
                                      size_t outLength,
                                      op::DataType yDtype,
                                      aclOpExecutor *executor);
}
#endif  // OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H

View File

@@ -0,0 +1,26 @@
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
namespace optiling {
// Tiling data passed from host-side tiling to the MoeGroupedMatmul kernel.
// Field meanings are inferred from standard GEMM naming — confirm against the
// kernel implementation.
BEGIN_TILING_DATA_DEF(MoeGroupedMatmulTilingData)
TILING_DATA_FIELD_DEF(uint32_t, group_num);  // number of groups
TILING_DATA_FIELD_DEF(uint32_t, core_num);   // number of AI cores used
TILING_DATA_FIELD_DEF(uint32_t, m);          // presumably matmul M dim
TILING_DATA_FIELD_DEF(uint32_t, n);          // presumably matmul N dim
TILING_DATA_FIELD_DEF(uint32_t, k);          // presumably matmul K dim
TILING_DATA_FIELD_DEF(uint32_t, single_m);   // presumably per-core M tile — TODO confirm
TILING_DATA_FIELD_DEF(uint32_t, single_n);   // presumably per-core N tile — TODO confirm
TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm_tiling);  // nested matmul-API tiling struct
END_TILING_DATA_DEF;
// Bind the tiling struct to the MoeGroupedMatmul op type.
REGISTER_TILING_DATA_CLASS(MoeGroupedMatmul, MoeGroupedMatmulTilingData)
} // namespace optiling