[MOE] commit GMM custom operator (#7010)
### What this PR does / why we need it?
Adds a GMM (grouped matmul) custom operator optimized for small-batch scenarios.
### How was this patch tested?
The GMM custom operator is submitted on its own here, for subsequent integration into the MoE
process.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
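For reviewers, a minimal sketch of how the new operator is expected to be driven from Python once the MoE integration lands. The shapes, the key-value `group_list` layout, and the attribute values below are assumptions inferred from the tiling/infershape checks in this patch, not a tested invocation:

```python
import torch
import torch_npu  # Ascend backend; assumed available alongside the vllm-ascend extension

# Assumed layout: x is [m, k], weight is [num_experts, k, n] (ND),
# group_list is [num_experts, 2] in key-value mode: (expert_id, row_count) per group.
m, k, n, num_experts = 32, 256, 512, 4
x = torch.randn(m, k, dtype=torch.bfloat16).npu()
weight = torch.randn(num_experts, k, n, dtype=torch.bfloat16).npu()
group_list = torch.tensor([[0, 8], [1, 8], [2, 8], [3, 8]], dtype=torch.int64).npu()

# split_item / group_type / group_list_type values are placeholders, not verified against the kernel.
y = torch.ops._C_ascend.moe_grouped_matmul(x, weight, group_list, 0, 0, 1)
print(y[0].shape)  # expected [m, n]
```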
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
     ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
     export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}

-    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
+    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
     SOC_ARG="ascend910b"
 elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
     # ASCEND910C (A3) series
@@ -82,6 +82,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
         "add_rms_norm_bias"
         "apply_top_k_top_p_custom"
         "transpose_kv_cache_by_block"
+        "moe_grouped_matmul"
     )
     CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
     SOC_ARG="ascend910_93"
csrc/moe_grouped_matmul/op_host/CMakeLists.txt (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

add_ops_compile_options(
    OP_NAME MoeGroupedMatmulCustom
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
)

target_sources(optiling PRIVATE
    moe_grouped_matmul_cpu.cpp
)

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${ASCEND_CANN_PACKAGE_PATH}/include
    ${ASCEND_CANN_PACKAGE_PATH}/include/external
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)

target_sources(opsproto PRIVATE
    moe_grouped_matmul_infershape.cpp
)

target_include_directories(opsproto PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${ASCEND_CANN_PACKAGE_PATH}/include
    ${ASCEND_CANN_PACKAGE_PATH}/include/external
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)

target_sources(op_host_aclnnInner PRIVATE
    moe_grouped_matmul_def.cpp
)

target_include_directories(op_host_aclnnInner PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${ASCEND_CANN_PACKAGE_PATH}/include
    ${ASCEND_CANN_PACKAGE_PATH}/include/external
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/metadef
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime
    ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof
)

target_sources(opapi PRIVATE
    moe_grouped_matmul_l0.cpp
    aclnn_moe_grouped_matmul.cpp
)

install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul.h"
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL)

install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_grouped_matmul_weight_nz.h"
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL)
csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.cpp (new file, 248 lines)
@@ -0,0 +1,248 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "aclnn_moe_grouped_matmul.h"
#include "aclnn_moe_grouped_matmul_weight_nz.h"

#include <dlfcn.h>
#include <new>

#include "aclnn_kernels/transdata.h"
#include "moe_grouped_matmul_l0.h"
#include "aclnn_kernels/contiguous.h"
#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/make_op_executor.h"

using namespace op;

#ifdef __cplusplus
extern "C" {
#endif

namespace {
static constexpr size_t ALIGN_NZ_4BIT_N = 64UL;
static constexpr size_t ALIGN_NZ_4BIT_K = 64UL;
static constexpr size_t ALIGN_NZ_INT8_N = 32UL;
static constexpr size_t ALIGN_NZ_K = 16UL;
static constexpr size_t DIMS_THREE_FOR_GMM = 3UL;
static constexpr size_t LAST_FIRST_DIM_INDEX = 1;
static constexpr size_t LAST_SECOND_DIM_INDEX = 2;
static constexpr size_t LAST_THIRD_DIM_INDEX = 3;

enum class GMMWeightVersion : uint32_t {
    WeightNd = 1U,
    WeightNz = 2U
};

struct MoeGroupedMatmulParams {
    const aclTensorList *x = nullptr;
    const aclTensorList *weight = nullptr;
    const aclTensor *groupTensor = nullptr;
    bool transposeWeight = false;
    bool isSingleWeight = false;
    GMMWeightVersion weightVersion = GMMWeightVersion::WeightNd;
    const aclTensorList *y = nullptr;
    DataType xDtype = DataType::DT_BF16;
};
}

namespace {

static aclnnStatus CheckNotNull(const aclTensorList *x, const aclTensorList *weight, const aclTensorList *y) {
    CHECK_COND(x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "x must not be nullptr.");
    CHECK_COND(x->Size() != 0, ACLNN_ERR_PARAM_INVALID, "x must not be empty tensorlist.");
    CHECK_COND(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight must not be nullptr.");
    CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "weight must not be empty tensorlist.");
    CHECK_COND(y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "y must not be nullptr.");
    CHECK_COND(y->Size() != 0, ACLNN_ERR_PARAM_INVALID, "y must not be empty tensorlist.");
    return ACLNN_SUCCESS;
}

static aclnnStatus TransWeightToNzCheckAlign(MoeGroupedMatmulParams &gmmParams, const aclTensor *weight)
{
    size_t viewDimNum = weight->GetViewShape().GetDimNum();
    uint64_t k = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - 1) :
                 weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX);
    uint64_t n = gmmParams.transposeWeight ? weight->GetViewShape().GetDim(viewDimNum - LAST_SECOND_DIM_INDEX) :
                 weight->GetViewShape().GetDim(viewDimNum - 1);
    bool k_align = false;
    bool n_align = false;
    if (weight->GetDataType() == DataType::DT_BF16 || weight->GetDataType() == DataType::DT_FLOAT16) {
        k_align = k % ALIGN_NZ_K == 0;
        n_align = n % ALIGN_NZ_K == 0;
    }
    CHECK_COND(k_align == true && n_align == true, ACLNN_ERR_PARAM_INVALID,
               "When the weight format is FRACTAL_NZ, the weight shape (k[%lu], n[%lu]) must be divisible by "
               "[16, 16] for BF16/FP16. If the weight is transposed, k and n are swapped.",
               k, n);
    return ACLNN_SUCCESS;
}

static aclnnStatus TransWeightToNz(MoeGroupedMatmulParams &gmmParams, aclOpExecutor *executor) {
    const aclTensorList *&weights = gmmParams.weight;
    const aclTensorList *&x = gmmParams.x;
    CHECK_COND((*x)[0] != nullptr, ACLNN_ERR_PARAM_INVALID, "The first tensor of x is nullptr!");
    size_t wLength = weights->Size();
    for (size_t i(0); i < wLength; ++i) {
        const aclTensor* weight = (*weights)[i];
        if (weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ &&
            weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_16 &&
            weight->GetStorageFormat() != op::Format::FORMAT_FRACTAL_NZ_C0_32) {
            break;
        }
        CHECK_COND(TransWeightToNzCheckAlign(gmmParams, weight) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
                   "NZ weight alignment check failed.");
    }
    return ACLNN_SUCCESS;
}

static const aclTensor *SetTensorToNZFormat(const aclTensor *input, op::Shape &shape, aclOpExecutor *executor) {
    auto formatTensor = executor->CreateView(input, shape, input->GetViewOffset());
    formatTensor->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ);
    formatTensor->SetOriginalFormat(input->GetViewFormat());
    formatTensor->SetViewShape(input->GetViewShape());
    return formatTensor;
}

static aclnnStatus DataContiguous(const aclTensorList *&tensors, aclOpExecutor *executor) {
    std::vector<const aclTensor *> tensorsVec;
    const aclTensor *contiguousTensor = nullptr;
    for (size_t i = 0; i < tensors->Size(); ++i) {
        const aclTensor *tensor = (*tensors)[i];
        contiguousTensor = l0op::Contiguous(tensor, executor);
        CHECK_RET(contiguousTensor != nullptr, ACLNN_ERR_INNER_NULLPTR);
        tensorsVec.push_back(contiguousTensor);
    }
    tensors = executor->AllocTensorList(tensorsVec.data(), tensorsVec.size());
    return ACLNN_SUCCESS;
}

static aclnnStatus ParamsDataContiguous(MoeGroupedMatmulParams &params, aclOpExecutor *executorPtr) {
    CHECK_COND(DataContiguous(params.x, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "Contiguous x failed."); // make x contiguous
    DataType xDtype = (*params.x)[0]->GetDataType();
    DataType weightDtype = (*params.weight)[0]->GetDataType();
    CHECK_COND(DataContiguous(params.weight, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "Contiguous weight failed."); // make w contiguous
    params.groupTensor = l0op::Contiguous(params.groupTensor, executorPtr);
    CHECK_COND(params.groupTensor != nullptr, ACLNN_ERR_PARAM_INVALID,
               "Contiguous groupTensor failed.");
    return ACLNN_SUCCESS;
}

static aclnnStatus GetGMMResultByL0Api(MoeGroupedMatmulParams &params, uint64_t *workspaceSize, aclOpExecutor **executor) {
    auto uniqueExecutor = CREATE_EXECUTOR(); // fixed written style, create OpExecutor
    aclOpExecutor *executorPtr = uniqueExecutor.get();
    CHECK_RET(executorPtr != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
    // op::Shape wqbmmNzShape = (*params.weight)[0]->GetStorageShape();

    CHECK_COND(ParamsDataContiguous(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "ParamsDataContiguous failed.");
    if (params.weightVersion == GMMWeightVersion::WeightNz) {
        std::vector<const aclTensor *> tensorsVec;
        for (size_t i = 0; i < params.weight->Size(); ++i) {
            const aclTensor *tensor = (*params.weight)[i];
            op::Shape weightNzShape = tensor->GetViewShape();
            tensor = SetTensorToNZFormat(tensor, weightNzShape, executorPtr);
            tensorsVec.push_back(tensor);
        }
        params.weight = executorPtr->AllocTensorList(tensorsVec.data(), tensorsVec.size());
    }
    CHECK_COND(TransWeightToNz(params, executorPtr) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "TransWeightToNz failed.");
    // Invoke l0 operator MoeGroupedMatmul for calculation.
    auto result = l0op::MoeGroupedMatmul(params.x, params.weight,
                                         params.groupTensor, params.transposeWeight,
                                         (*params.y)[0]->GetViewShape(), params.y->Size(),
                                         (*params.y)[0]->GetDataType(), executorPtr);
    CHECK_RET(result != nullptr, ACLNN_ERR_INNER_NULLPTR);
    for (size_t i(0); i < params.y->Size(); ++i) {
        auto viewCopyResult = l0op::ViewCopy((*result)[i], (*params.y)[i], executorPtr);
        CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
    }
    // Standard syntax, get the size of workspace needed during computation.
    *workspaceSize = uniqueExecutor->GetWorkspaceSize();
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
}

static aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(const aclTensorList *x, const aclTensorList *weight,
        const aclTensor *groupList, bool transposeWeight, GMMWeightVersion weightVersion,
        const aclTensorList *y, uint64_t *workspaceSize, aclOpExecutor **executor) {
    DataType xDtype = DataType::DT_UNDEFINED;
    for (size_t i = 0; i < x->Size(); ++i) {
        if ((*x)[i] != nullptr) {
            xDtype = (*x)[i]->GetDataType();
            break;
        }
    }
    MoeGroupedMatmulParams moeGmmParams{x, weight, groupList, transposeWeight, true, weightVersion, y, xDtype};
    aclnnStatus ret = GetGMMResultByL0Api(moeGmmParams, workspaceSize, executor);
    return ret;
}
}

aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
        const aclTensor *groupList, bool transposeWeight, aclTensorList *out,
        uint64_t *workspaceSize, aclOpExecutor **executor) {
    CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
               "one of the required inputs is nullptr.");
    // Standard syntax, check parameters.
    L2_DFX_PHASE_1(aclnnMoeGroupedMatmulWeightNz,
                   DFX_IN(x, weight, groupList),
                   DFX_OUT(out));
    return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNz, out, workspaceSize, executor);
}

aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize(const aclTensorList *x, const aclTensorList *weight,
        const aclTensor *groupList, bool transposeWeight, aclTensorList *out,
        uint64_t *workspaceSize, aclOpExecutor **executor) {
    CHECK_COND(CheckNotNull(x, weight, out) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR,
               "one of the required inputs is nullptr.");
    // Standard syntax, check parameters.
    L2_DFX_PHASE_1(aclnnMoeGroupedMatmul,
                   DFX_IN(x, weight, groupList),
                   DFX_OUT(out));
    CHECK_COND(weight->Size() != 0, ACLNN_ERR_PARAM_INVALID, "weight must not be an empty tensorlist.");
    return aclnnMoeGroupedMatmulGetWorkspaceSizeCommon(x, weight, groupList, transposeWeight, GMMWeightVersion::WeightNd, out, workspaceSize, executor);
}

aclnnStatus aclnnMoeGroupedMatmul(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                  aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnMoeGroupedMatmul);
    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "GMM AI Core launch failed.");
    return ACLNN_SUCCESS;
}

aclnnStatus aclnnMoeGroupedMatmulWeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                          aclrtStream stream) {
    L2_DFX_PHASE_2(aclnnMoeGroupedMatmulWeightNz);
    CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER,
               "GMM AI Core launch failed.");
    return ACLNN_SUCCESS;
}

#ifdef __cplusplus
}
#endif
csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef OP_API_INC_MOE_GROUPED_MATMUL_H
#define OP_API_INC_MOE_GROUPED_MATMUL_H
#include "aclnn/aclnn_base.h"

#ifdef __cplusplus
extern "C" {
#endif

__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulGetWorkspaceSize(
    const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList,
    bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor);

__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
    aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
csrc/moe_grouped_matmul/op_host/aclnn_moe_grouped_matmul_weight_nz.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H
#define OP_API_INC_MOE_GROUPED_MATMUL_WEIGHT_NZ_H
#include "aclnn/aclnn_base.h"

#ifdef __cplusplus
extern "C" {
#endif

__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNzGetWorkspaceSize(
    const aclTensorList *x, const aclTensorList *weight, const aclTensor *groupList,
    bool transposeWeight, aclTensorList *out, uint64_t *workspaceSize, aclOpExecutor **executor);

__attribute__((visibility("default"))) aclnnStatus aclnnMoeGroupedMatmulWeightNz(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
    aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_cpu.cpp (new file, 336 lines)
@@ -0,0 +1,336 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "moe_grouped_matmul_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"

#define OP_LOGD(nodeName, fmt, ...) do { printf(fmt, ##__VA_ARGS__); printf("\n"); } while (0)
#define OP_LOGE(nodeName, fmt, ...) do { printf(fmt, ##__VA_ARGS__); printf("\n"); } while (0)
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t GROUPLIST_INDEX = 2;
namespace optiling {
constexpr uint64_t BEST_L1_PARTA = 256UL * 1024UL;
constexpr uint64_t BEST_L1_PARTB = 128UL * 1024UL;
constexpr uint64_t L1_PARTA_SIZE = 256UL * 1024UL;
constexpr int32_t BEST_BASEN = 256;
constexpr uint64_t DOUBLE_BUFFER_L0A_L0B = 2;
constexpr uint64_t DOUBLE_BUFFER_STEPKA_STEPKB = 2;
constexpr uint32_t FP32_DATATYPE_SIZE = 4;
constexpr int32_t MAX_BASEM = 256;

static inline uint32_t SixteenAlign(uint32_t a, bool up = false) {
    if (up) {
        a += 15U; // 15: 16 bytes up-align
    }
    return a & ~15U; // ~15: 16 bytes down-align
}

static inline int64_t SixteenAlign(int64_t a, bool up = false) {
    if (up) {
        a += 15; // 15: 16 bytes up-align
    }
    return a & ~15; // ~15: 16 bytes down-align
}

class TilingMoeGroupedMatmulFunc {
public:
    explicit TilingMoeGroupedMatmulFunc(gert::TilingContext* tiling_context)
        : tiling_context_(tiling_context) {}

    ge::graphStatus Init();
    ge::graphStatus RunKernelTiling();

private:
    MoeGroupedMatmulTilingData tiling_data_;
    gert::TilingContext* tiling_context_ = nullptr;

    void SetTilingKey();
    void FillTilingData();
    ge::graphStatus CalMMTiling();
    ge::graphStatus GMMSetMMTiling();
    ge::graphStatus CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb);
    ge::graphStatus DynamicTIlingSingleN();

    void InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo);
    void GMMGetPlatformInfo();
    int64_t m_ = 0L;
    int64_t n_ = 0L;
    int64_t k_ = 0L;
    bool transpose_weight = false;
    bool weight_nz = false;
    uint32_t single_m_ = 0;
    uint32_t single_n_ = 0;
    int32_t baseM_ = 0;
    int32_t baseN_ = 0;
    int32_t baseK_ = 0;
    int32_t nz_factor_ = 1;
    uint32_t group_num_ = 0;
    uint32_t core_num_ = 0;
    uint32_t mmDataTypeSize_ = 0;
    size_t sync_workspace_size_ = 0;
    ge::DataType x_dtype;
    uint64_t l1_size, l0a_size, l0b_size, l0c_size, ub_size;
    uint32_t aic_num, aiv_num;
    // SocVersion soc_version;
};
ge::graphStatus TilingMoeGroupedMatmulFunc::CalcStepKaKb(uint32_t& mm_step_ka, uint32_t& mm_step_kb) {
    uint64_t available_l1_size = l1_size;
    if (available_l1_size < L1_PARTA_SIZE) {
        OP_LOGE(tiling_context_->GetNodeName(), "available_l1_size is less than 256k.");
        return ge::GRAPH_FAILED;
    }
    // according to double buffer, recompute the params used for data movement from GM to L1
    uint64_t l1_a_size = baseM_ > baseN_ ? L1_PARTA_SIZE : available_l1_size - L1_PARTA_SIZE;
    uint64_t l1_b_size = available_l1_size - l1_a_size;
    // 2: double buffer
    mm_step_ka = (l1_a_size / 2UL) / (static_cast<uint64_t>(baseM_) * baseK_ * mmDataTypeSize_);
    // 2: double buffer
    mm_step_kb = (l1_b_size / 2UL) / (static_cast<uint64_t>(baseN_) * baseK_ * mmDataTypeSize_);
    if (mm_step_ka == 0 || mm_step_kb == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "stepka or stepkb cannot be 0.");
        return ge::GRAPH_FAILED;
    }

    if (mm_step_ka > mm_step_kb) {
        mm_step_ka = mm_step_ka / mm_step_kb * mm_step_kb;
    } else if (mm_step_ka < mm_step_kb) {
        mm_step_kb = mm_step_kb / mm_step_ka * mm_step_ka;
    }
    return ge::GRAPH_SUCCESS;
}

void TilingMoeGroupedMatmulFunc::GMMGetPlatformInfo() {
    auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo());
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0a_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0b_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0c_size);
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size);
    aic_num = platform_info.GetCoreNumAic();
    aiv_num = platform_info.GetCoreNumAiv();
}

void TilingMoeGroupedMatmulFunc::InitPlatformInfo(matmul_tiling::PlatformInfo& platformInfo) {
    auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo());
    platformInfo.socVersion = platform_info.GetSocVersion();
    platformInfo.l1Size = l1_size;
    platformInfo.l0CSize = l0c_size;
    platformInfo.ubSize = ub_size;
    platformInfo.l0ASize = l0a_size;
    platformInfo.l0BSize = l0b_size;
}
ge::graphStatus TilingMoeGroupedMatmulFunc::GMMSetMMTiling() {
    matmul_tiling::DataType matmul_dtype = static_cast<matmul_tiling::DataType>(x_dtype);
    matmul_tiling::PlatformInfo platformInfo;
    InitPlatformInfo(platformInfo);
    // matmul_tiling::MatmulApiTiling mm(platformInfo);
    matmul_tiling::MultiCoreMatmulTiling mm(platformInfo);
    mm.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, false);
    if (weight_nz) {
        mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_dtype, transpose_weight);
    } else {
        mm.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype, transpose_weight);
    }
    mm.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_dtype);
    mm.SetOrgShape(m_, n_, k_);
    mm.SetShape(m_, baseN_, k_);
    // mm.SetShape(single_m_, single_n_, k_);
    mm.SetFixSplit(baseM_, baseN_, baseK_);
    mm.SetBufferSpace(l1_size, l0c_size, ub_size);

    if (mm.GetTiling(tiling_data_.mm_tiling) == -1) {
        OP_LOGE(tiling_context_->GetNodeName(), "matmul getTiling failed.");
        return ge::GRAPH_FAILED;
    }
    uint32_t mm_step_ka = 1;
    uint32_t mm_step_kb = 1;

    auto ret = CalcStepKaKb(mm_step_ka, mm_step_kb);
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(tiling_context_->GetNodeName(), "matmul calc stepka or stepkb failed.");
        return ge::GRAPH_FAILED;
    }

    constexpr uint32_t step_m = 1; // 1: step_m set fixed value 1
    constexpr uint32_t step_n = 1; // 1: step_n set fixed value 1
    uint32_t mm_depth_a1 = mm_step_ka * DOUBLE_BUFFER_STEPKA_STEPKB * step_m;
    uint32_t mm_depth_b1 = mm_step_kb * DOUBLE_BUFFER_STEPKA_STEPKB * step_n;
    tiling_data_.mm_tiling.set_shareMode(0);
    tiling_data_.mm_tiling.set_dbL0C(1); // disable double buffer for L0C
    tiling_data_.mm_tiling.set_baseM(baseM_); // set precomputed baseM
    tiling_data_.mm_tiling.set_baseN(baseN_); // set precomputed baseN
    tiling_data_.mm_tiling.set_baseK(baseK_); // set precomputed baseK
    tiling_data_.mm_tiling.set_stepKa(mm_step_ka); // set precomputed mmStepKa
    tiling_data_.mm_tiling.set_depthA1(mm_depth_a1); // set precomputed mmDepthA1
    tiling_data_.mm_tiling.set_stepKb(mm_step_kb); // set precomputed mmStepKb
    tiling_data_.mm_tiling.set_depthB1(mm_depth_b1); // set precomputed mmDepthB1
    tiling_data_.mm_tiling.set_stepM(step_m); // set precomputed stepM
    tiling_data_.mm_tiling.set_stepN(step_n); // set precomputed stepN
    OP_LOGD(tiling_context_->GetNodeName(), "GMM_tiling: baseM is %d, baseK is %d, baseN is %d, transpose_weight is %d, weight_nz is %d", baseM_, baseK_, baseN_, transpose_weight, weight_nz);
    return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingMoeGroupedMatmulFunc::CalMMTiling() {
    baseN_ = BEST_BASEN;
    if (x_dtype == ge::DT_BF16 || x_dtype == ge::DT_FLOAT16) {
        mmDataTypeSize_ = 2;
    } else {
        OP_LOGE(tiling_context_->GetNodeName(), "only support bf16 or fp16.");
        return ge::GRAPH_FAILED;
    }

    baseK_ = static_cast<int32_t>((l0b_size / DOUBLE_BUFFER_L0A_L0B) / (static_cast<uint32_t>(baseN_) * mmDataTypeSize_));
    baseK_ = static_cast<int32_t>(SixteenAlign(static_cast<int64_t>(baseK_)));
    uint32_t max_base_m = static_cast<uint32_t>(l0c_size /
                                                (static_cast<uint32_t>(baseN_) * FP32_DATATYPE_SIZE));
    baseM_ = std::min<uint32_t>((l0a_size / DOUBLE_BUFFER_L0A_L0B) /
                                (static_cast<uint32_t>(baseK_) * mmDataTypeSize_), max_base_m);
    baseM_ = baseM_ > m_ ? SixteenAlign(m_, true) : SixteenAlign(static_cast<uint32_t>(baseM_));

    if (baseM_ > MAX_BASEM) {
        baseM_ = MAX_BASEM;
    }

    if (baseM_ == 0 || baseK_ == 0) {
        OP_LOGE(tiling_context_->GetNodeName(), "baseM_ or baseK_ cannot be 0.");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
ge::graphStatus TilingMoeGroupedMatmulFunc::Init() {
    GMMGetPlatformInfo();
    // only support single x, single weight, single y
    bool is_single_x = (tiling_context_->GetDynamicInputTensor(X_INDEX, 1) == nullptr);
    bool is_single_weight = (tiling_context_->GetDynamicInputTensor(WEIGHT_INDEX, 1) == nullptr);
    bool is_single_y = (tiling_context_->GetOutputShape(1) == nullptr);
    transpose_weight = static_cast<bool>(*(tiling_context_->GetAttrs()->GetAttrPointer<bool>(0)));
    if (!(is_single_x && is_single_weight && is_single_y)) {
        OP_LOGE(tiling_context_->GetNodeName(), "only support single x, single weight and single y.");
        return ge::GRAPH_FAILED;
    }

    auto x_shape = tiling_context_->GetDynamicInputShape(X_INDEX, 0)->GetOriginShape();
    auto weight_shape = tiling_context_->GetDynamicInputShape(WEIGHT_INDEX, 0)->GetOriginShape();
    auto group_list_shape = tiling_context_->GetInputShape(GROUPLIST_INDEX)->GetOriginShape();
    auto y_shape = tiling_context_->GetOutputShape(0)->GetOriginShape();
    auto weight_desc = tiling_context_->GetDynamicInputDesc(WEIGHT_INDEX, 0);
    auto weight_format = static_cast<ge::Format>(ge::GetPrimaryFormat(weight_desc->GetStorageFormat()));
    x_dtype = tiling_context_->GetDynamicInputDesc(X_INDEX, 0)->GetDataType();

    // printf("weight_format %d\n", weight_format);
    weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ;

    // check input shape
    if (x_shape.GetDimNum() != 2 || y_shape.GetDimNum() != 2) {
        OP_LOGE(tiling_context_->GetNodeName(), "the dimNum of input and output should be 2, but got %zu, %zu.", static_cast<size_t>(x_shape.GetDimNum()), static_cast<size_t>(y_shape.GetDimNum()));
        return ge::GRAPH_FAILED;
    }
    uint32_t weight_dim1, weight_dim2;
    n_ = transpose_weight ? weight_shape.GetDim(1) : weight_shape.GetDim(2);

    if (group_list_shape.GetDimNum() != 2) {
        OP_LOGE(tiling_context_->GetNodeName(), "only support key-value mode of groupList, the dimNum of groupList should be 2, but got %zu.", static_cast<size_t>(group_list_shape.GetDimNum()));
        return ge::GRAPH_FAILED;
    }

    m_ = x_shape.GetDim(x_shape.GetDimNum() - 2);
    k_ = x_shape.GetDim(x_shape.GetDimNum() - 1);
    group_num_ = weight_shape.GetDim(0);

    if (group_list_shape.GetDim(0) != group_num_) {
        OP_LOGE(tiling_context_->GetNodeName(), "the dim0 of input weight should be equal to the dim0 of groupList, but got %zu, %zu.", static_cast<size_t>(weight_shape.GetDim(0)), static_cast<size_t>(group_list_shape.GetDim(0)));
    }
    single_m_ = 128;
    single_n_ = 256;

    core_num_ = aic_num;
    auto n_task_num = (n_ + single_n_ - 1) / single_n_;
    auto task_num = m_ * n_task_num;
    if (task_num < core_num_) {
        core_num_ = task_num;
    }

    auto platform_info = platform_ascendc::PlatformAscendC(tiling_context_->GetPlatformInfo());
    sync_workspace_size_ = static_cast<size_t>(platform_info.GetLibApiWorkSpaceSize());
    return ge::GRAPH_SUCCESS;
}
void TilingMoeGroupedMatmulFunc::SetTilingKey() {
    uint64_t tiling_key = 10UL;
    if (transpose_weight) {
        tiling_key = tiling_key + 1UL;
    }
    tiling_context_->SetTilingKey(tiling_key);
}

void TilingMoeGroupedMatmulFunc::FillTilingData() {
    tiling_data_.set_m(static_cast<uint32_t>(m_));
    tiling_data_.set_n(static_cast<uint32_t>(n_));
    tiling_data_.set_k(static_cast<uint32_t>(k_));
    tiling_data_.set_single_m(single_m_);
    tiling_data_.set_single_n(single_n_);
    tiling_data_.set_group_num(group_num_);
    tiling_data_.set_core_num(core_num_);
}
ge::graphStatus TilingMoeGroupedMatmulFunc::RunKernelTiling() {
    auto ret = CalMMTiling();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(tiling_context_->GetNodeName(), "cal mmtiling failed.");
        return ge::GRAPH_FAILED;
    }
    ret = GMMSetMMTiling();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(tiling_context_->GetNodeName(), "gmm set mmtiling failed.");
        return ge::GRAPH_FAILED;
    }
    SetTilingKey();
    FillTilingData();
    size_t userWorkspaceSize = 0;
    size_t *currentWorkspace = tiling_context_->GetWorkspaceSizes(1);
    currentWorkspace[0] = userWorkspaceSize + sync_workspace_size_;
    tiling_data_.SaveToBuffer(tiling_context_->GetRawTilingData()->GetData(),
                              tiling_context_->GetRawTilingData()->GetCapacity());
    tiling_context_->GetRawTilingData()->SetDataSize(tiling_data_.GetDataSize());
    tiling_context_->SetBlockDim(core_num_);
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus TilingForMoeGroupedMatmulFunc(gert::TilingContext *context) {
    TilingMoeGroupedMatmulFunc tilingObject(context);
    auto ret = tilingObject.Init();
    if (ret != ge::GRAPH_SUCCESS) {
        OP_LOGE(context->GetNodeName(), "tiling Init failed.");
        return ge::GRAPH_FAILED;
    }
    ret = tilingObject.RunKernelTiling();
    return ret;
}

struct MatmulAllreduceAddRmsnormCompileInfo1 {};
ge::graphStatus TilingParseForMatmulAllreduceAddRmsnorm1(gert::TilingParseContext *context)
{
    // (void)context;
    return ge::GRAPH_SUCCESS;
}

IMPL_OP_OPTILING(MoeGroupedMatmul)
    .Tiling(TilingForMoeGroupedMatmulFunc)
    .TilingParse<MatmulAllreduceAddRmsnormCompileInfo1>(TilingParseForMatmulAllreduceAddRmsnorm1);

} // namespace optiling
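To make the tiling choices above easier to review, here is a small Python sketch (not part of the patch) of the base-block derivation performed by `CalMMTiling`. The L0A/L0B/L0C sizes below are assumed 910B-class values for illustration, not values read from the platform, and the clamp of `baseM_` to `m_` is omitted:

```python
# Sketch of TilingMoeGroupedMatmulFunc::CalMMTiling for bf16/fp16 inputs.
# Assumed buffer sizes (illustrative only): L0A = L0B = 64 KiB, L0C = 128 KiB.
L0A, L0B, L0C = 64 * 1024, 64 * 1024, 128 * 1024
dtype_size = 2                                   # bf16 / fp16
base_n = 256                                     # BEST_BASEN
base_k = (L0B // 2) // (base_n * dtype_size)     # double-buffered L0B
base_k &= ~15                                    # SixteenAlign (round down)
max_base_m = L0C // (base_n * 4)                 # fp32 accumulators in L0C
base_m = min((L0A // 2) // (base_k * dtype_size), max_base_m)
base_m = min(base_m & ~15, 256)                  # align, then cap at MAX_BASEM
print(base_m, base_n, base_k)                    # -> 128 256 64 under these assumptions
```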
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_def.cpp (new file, 53 lines)
@@ -0,0 +1,53 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file moe_grouped_matmul_def.cpp
 * \brief
 */

#include "register/op_def_registry.h"
#include "moe_grouped_matmul_infershape.cpp"

namespace ops {
class MoeGroupedMatmul : public OpDef {
public:
    explicit MoeGroupedMatmul(const char *name) : OpDef(name) {
        this->Input("x")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("weight")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        this->Input("group_list")
            .ParamType(REQUIRED)
            .DataTypeList({ge::DT_INT64, ge::DT_INT32})
            .FormatList({ge::FORMAT_ND});
        this->Output("y")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Attr("transpose_weight")
            .AttrType(OPTIONAL)
            .Bool(false);
        this->SetInferShape(ge::InferShape);
        this->SetInferDataType(ge::InferDataType);
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(MoeGroupedMatmul);

} // namespace ops
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_infershape.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
#include "register/op_impl_registry.h"

#include <string>

namespace ge {
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t GROUPLIST_INDEX = 2;

static ge::graphStatus InferShape(gert::InferShapeContext *context) {
    const gert::Shape* x_shape = context->GetDynamicInputShape(X_INDEX, 0);
    const gert::Shape* weight_shape = context->GetDynamicInputShape(WEIGHT_INDEX, 0);
    bool transpose_weight = static_cast<bool>(*(context->GetAttrs()->GetAttrPointer<bool>(0)));
    gert::Shape* y_shape = context->GetOutputShape(0);
    *y_shape = *x_shape;
    auto weight_desc = context->GetDynamicInputDesc(WEIGHT_INDEX, 0);
    auto weight_format = static_cast<ge::Format>(ge::GetPrimaryFormat(weight_desc->GetStorageFormat()));
    bool weight_nz = weight_format == ge::FORMAT_FRACTAL_NZ;
    int64_t dim_n;
    if (weight_nz) {
        dim_n = transpose_weight ? (weight_shape->GetDim(1) * weight_shape->GetDim(3)) :
                (weight_shape->GetDim(2) * weight_shape->GetDim(4));
    } else {
        dim_n = transpose_weight ? weight_shape->GetDim(1) : weight_shape->GetDim(2);
    }
    y_shape->SetDim(1, dim_n);
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) {
    const auto input_dtype = context->GetDynamicInputDataType(X_INDEX, 0);
    context->SetOutputDataType(0, input_dtype);
    return ge::GRAPH_SUCCESS;
}
} // namespace ge
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.cpp (new file, 50 lines)
@@ -0,0 +1,50 @@
#include "moe_grouped_matmul_l0.h"
#include "opdev/op_log.h"
#include "opdev/op_dfx.h"
#include "opdev/shape_utils.h"
#include "opdev/make_op_executor.h"

using namespace op;

namespace l0op {
OP_TYPE_REGISTER(MoeGroupedMatmul);

const aclTensorList *MoeGroupedMatmul(const aclTensorList *x,
                                      const aclTensorList *weight,
                                      const aclTensor *groupList,
                                      bool transposeWeight,
                                      op::Shape yShape,
                                      size_t outLength,
                                      op::DataType yDtype,
                                      aclOpExecutor *executor) {
    L0_DFX(MoeGroupedMatmul, x, weight, groupList, transposeWeight, outLength);

    std::vector<const aclTensor*> tensorsVec;
    const aclTensor *x0 = x->Size() > 0 ? (*x)[0] : nullptr;
    if (x0 == nullptr) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "(*x)[0] is nullptr.");
        return nullptr;
    }

    for (size_t i(0); i < outLength; ++i) {
        tensorsVec.emplace_back(executor->AllocTensor(yShape, yDtype));
    }
    auto out = executor->AllocTensorList(tensorsVec.data(), outLength);

    auto x0_dim_num = x0->GetStorageShape().GetDimNum();
    auto x0_dim0 = x0->GetStorageShape().GetDim(0);
    printf("x0_dim_num %zu x0_dim0 %ld\n", x0_dim_num, static_cast<long>(x0_dim0));

    auto ret = ADD_TO_LAUNCHER_LIST_AICORE(MoeGroupedMatmul,
                                           OP_INPUT(x, weight, groupList),
                                           OP_OUTPUT(out),
                                           OP_ATTR(transposeWeight));
    if (ret != ACLNN_SUCCESS) {
        OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed.");
        return nullptr;
    }

    return out;
}

} // namespace l0op
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_l0.h (new file, 27 lines)
@@ -0,0 +1,27 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H
#define OP_API_INC_LEVEL0_OP_MOE_GROUPED_MATMUL_OP_H

#include "opdev/op_executor.h"

namespace l0op {
const aclTensorList *MoeGroupedMatmul(const aclTensorList *x,
                                      const aclTensorList *weight,
                                      const aclTensor *groupList,
                                      bool transposeWeight,
                                      op::Shape yShape,
                                      size_t outLength,
                                      op::DataType yDtype,
                                      aclOpExecutor *executor);
}

#endif
csrc/moe_grouped_matmul/op_host/moe_grouped_matmul_tiling.h (new file, 26 lines)
@@ -0,0 +1,26 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"

namespace optiling {
BEGIN_TILING_DATA_DEF(MoeGroupedMatmulTilingData)
    TILING_DATA_FIELD_DEF(uint32_t, group_num);
    TILING_DATA_FIELD_DEF(uint32_t, core_num);
    TILING_DATA_FIELD_DEF(uint32_t, m);
    TILING_DATA_FIELD_DEF(uint32_t, n);
    TILING_DATA_FIELD_DEF(uint32_t, k);
    TILING_DATA_FIELD_DEF(uint32_t, single_m);
    TILING_DATA_FIELD_DEF(uint32_t, single_n);
    TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mm_tiling);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MoeGroupedMatmul, MoeGroupedMatmulTilingData)
} // namespace optiling
csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.cpp (new file, 41 lines)
@@ -0,0 +1,41 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "moe_grouped_matmul.h"
#include "kernel_operator.h"

#if defined(FORMAT_WEIGHT) && FORMAT_WEIGHT == FORMAT_FRACTAL_NZ
constexpr CubeFormat formatWeight = CubeFormat::NZ;
#else
constexpr CubeFormat formatWeight = CubeFormat::ND;
#endif

//using namespace matmul;
#define GMM_CUBE_IMP(transWeight)                                                            \
    do {                                                                                     \
        if ASCEND_IS_AIV {                                                                   \
            return;                                                                          \
        }                                                                                    \
        GET_TILING_DATA(tiling_data, tiling);                                                \
        AscendC::TPipe pipe;                                                                 \
        KernelMoeGMMNoQuant<DTYPE_X, DTYPE_GROUP_LIST, formatWeight, transWeight> op(&pipe); \
        op.Init(x, weight, group_list, y, &tiling_data);                                     \
        op.Process();                                                                        \
    } while (0)

extern "C" __global__ __aicore__ void moe_grouped_matmul(GM_ADDR x, GM_ADDR weight, GM_ADDR group_list, GM_ADDR y,
                                                         GM_ADDR workSpace, GM_ADDR tiling) {

    if (TILING_KEY_IS(10UL)) {
        GMM_CUBE_IMP(false);
    } else if (TILING_KEY_IS(11UL)) {
        GMM_CUBE_IMP(true);
    }
}
csrc/moe_grouped_matmul/op_kernel/moe_grouped_matmul.h (new file, 186 lines)
@@ -0,0 +1,186 @@
/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"
#include "lib/matmul_intf.h"
using namespace AscendC;

constexpr MatmulConfig matmulCFGUnitFlag{false, false, true, 0, 0, 0, false, false, false, false, false, 0, 0, 0,
                                         0, 0, 0, 0, true};
struct GMMConfig {
    uint32_t m = 0;
    uint32_t k = 0;
    uint32_t n = 0;
    uint32_t baseM = 0;
    uint32_t baseN = 0;
    uint32_t mIdx = 0;
    uint32_t nIdx = 0;
    uint32_t blockDimM = 0;
    uint32_t blockDimN = 0;
    uint32_t singleM = 0;
    uint32_t singleN = 0;
    uint64_t wBaseOffset = 0;
    uint64_t nAxisBaseOffset = 0;
    uint64_t mAxisBaseOffset = 0;
    uint64_t xBaseOffset = 0;
    uint64_t yBaseOffset = 0;
    uint64_t wOutOffset = 0;
};
template <typename T, typename T2, CubeFormat formatWeight, bool transWeight>
class KernelMoeGMMNoQuant {

protected:
    using xType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T, false>;
    using weightType = MatmulType<AscendC::TPosition::GM, formatWeight, T, transWeight>;
    using yType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, T>;
    using biasType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    using mmT = matmul::MatmulImpl<xType, weightType, yType, biasType, matmulCFGUnitFlag>;
    mmT mm;

    MoeGroupedMatmulTilingData tiling_;
    AscendC::TPipe *pipe_ = nullptr;

    GlobalTensor<T> x_gm_;
    GlobalTensor<T> weight_gm_;
    GlobalTensor<T> y_gm_;
    GlobalTensor<T2> group_list_gm_;
    ListTensorDesc x_list_;
    ListTensorDesc weight_list_;
    ListTensorDesc y_list_;

    uint32_t core_idx;
    uint32_t used_core_num;
    constexpr static bool transposeW = transWeight;
    constexpr static uint32_t UB_BLOCK_UNIT_SIZE = 32;

public:
    __aicore__ inline KernelMoeGMMNoQuant(AscendC::TPipe *pipe) {pipe_ = pipe;}

    __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR group_list, GM_ADDR y, const MoeGroupedMatmulTilingData *tiling) {
        core_idx = GetBlockIdx();
        tiling_ = *tiling;
        used_core_num = GetBlockNum();
        group_list_gm_.SetGlobalBuffer((__gm__ T2*)group_list);
        x_list_.Init((__gm__ void*)x);
        weight_list_.Init((__gm__ void*)weight);
        y_list_.Init((__gm__ void*)y);
        GM_ADDR x_first_addr = (__gm__ uint8_t*)x_list_.GetDataPtr<__gm__ uint8_t>(0);
        GM_ADDR weight_first_addr = (__gm__ uint8_t*)weight_list_.GetDataPtr<__gm__ uint8_t>(0);
        GM_ADDR y_first_addr = (__gm__ uint8_t*)y_list_.GetDataPtr<__gm__ uint8_t>(0);
        x_gm_.SetGlobalBuffer((__gm__ T*)x_first_addr);
        weight_gm_.SetGlobalBuffer((__gm__ T*)weight_first_addr);
        y_gm_.SetGlobalBuffer((__gm__ T*)y_first_addr);

        mm.Init(&tiling_.mm_tiling, pipe_);
    }
    __aicore__ inline void Process() {

        uint32_t group_list_inner_shape = 2u;
        uint32_t group_list_shape_size = tiling_.group_num * group_list_inner_shape;
        GMMConfig mn_config;

        for (uint32_t loop = 0, count = 0; loop < group_list_shape_size; loop += group_list_inner_shape) {
            int32_t split_value = static_cast<int32_t>(group_list_gm_.GetValue(loop + 1));
            if (split_value <= 0) {
                break;
            }
            uint32_t group_idx = static_cast<uint32_t>(group_list_gm_.GetValue(loop));
            mn_config.mAxisBaseOffset += mn_config.m;
            mn_config.xBaseOffset += mn_config.m * mn_config.k;
            mn_config.yBaseOffset += mn_config.m * mn_config.n;
            this->SetMNConfig(split_value, mn_config);
            mn_config.nAxisBaseOffset = group_idx * mn_config.n;
            if constexpr (formatWeight == CubeFormat::NZ) {
                mn_config.wBaseOffset = AlignUp(mn_config.k, 16) * AlignUp(mn_config.nAxisBaseOffset, 16);
            } else {
                mn_config.wBaseOffset = mn_config.k * mn_config.nAxisBaseOffset;
            }
            mn_config.blockDimM = Ceil(mn_config.m, mn_config.singleM);
            mn_config.blockDimN = Ceil(mn_config.n, mn_config.singleN);
            uint32_t cur_count = count + mn_config.blockDimM * mn_config.blockDimN;
            uint32_t cur_block = this->core_idx >= count ? this->core_idx : this->core_idx + used_core_num;
            while (cur_block < cur_count) {
                mn_config.mIdx = (cur_block - count) / mn_config.blockDimN;
                mn_config.nIdx = (cur_block - count) % mn_config.blockDimN;
                this->MMCompute(group_idx, mn_config, this->core_idx);
                cur_block += used_core_num;
            }
            count = cur_count % used_core_num;
        }
    }
protected:
    __aicore__ inline uint32_t AlignUp(uint32_t a, uint32_t base) {
        return (a + base - 1) / base * base;
    }

    __aicore__ inline uint32_t Ceil(uint32_t a, uint32_t base) {
        if (base == 0) {
            return a;
        }
        return (a + base - 1) / base;
    }

    __aicore__ inline void SetMNConfig(const int32_t split_value, GMMConfig & mn_config) {
        mn_config.m = split_value;
        mn_config.k = tiling_.k;
        mn_config.n = tiling_.n;
        mn_config.baseM = tiling_.single_m;
        mn_config.baseN = tiling_.single_n;
        mn_config.singleM = mn_config.baseM;
        mn_config.singleN = mn_config.baseN;
    }

    __aicore__ inline void MMCompute(uint32_t group_idx, GMMConfig& mn_config, uint32_t core_idx) {
        uint32_t tail_n = mn_config.nIdx * mn_config.singleN;
        uint32_t cur_single_n = mn_config.nIdx < mn_config.blockDimN - 1 ? mn_config.singleN : mn_config.n - tail_n;
        uint32_t cur_single_m = mn_config.mIdx < mn_config.blockDimM - 1 ? mn_config.singleM
                                                                         : mn_config.m - mn_config.mIdx * mn_config.singleM;
        uint64_t x_offset = mn_config.mIdx * mn_config.singleM * mn_config.k;
        uint64_t out_offset = mn_config.mIdx * mn_config.singleM * mn_config.n + tail_n;
        GlobalTensor<T> weight_gm_local = GetGlobalBufferW(group_idx, tail_n, mn_config);

        mm.SetOrgShape(mn_config.m, mn_config.n, mn_config.k);
        mm.SetSingleShape(cur_single_m, cur_single_n, mn_config.k);
        mm.SetTensorA(x_gm_[mn_config.xBaseOffset + x_offset], false);
        mm.SetTensorB(weight_gm_local, transposeW);
        mm.template IterateAll<false>(y_gm_[mn_config.yBaseOffset + out_offset], 0);
    }

    __aicore__ inline GlobalTensor<T> GetGlobalBufferW(uint32_t group_idx, uint32_t tail_n, GMMConfig& mn_config) {
        uint64_t w_offset = SetWOffset(tail_n, mn_config.k);
        GlobalTensor<T> weight_gm_local;
        weight_gm_local = weight_gm_[mn_config.wBaseOffset + w_offset];
        if (mn_config.blockDimM == 1) {
            weight_gm_local.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE);
        }
        return weight_gm_local;
    }

    __aicore__ inline uint64_t SetWOffset(uint32_t tail_n, uint32_t k) {
        uint64_t w_offset = 0;
        if constexpr (formatWeight == CubeFormat::NZ && transposeW) {
            w_offset = tail_n * (UB_BLOCK_UNIT_SIZE / sizeof(T)); // 32: quant is 32, float16 is 16
        } else if constexpr (formatWeight == CubeFormat::NZ) {
            w_offset = tail_n * AlignUp(k, 16); // 16: nz format last two dim size
        } else if constexpr (transposeW) {
            w_offset = tail_n * k;
        } else {
            w_offset = tail_n;
        }
        return w_offset;
    }
};
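For review purposes, a small Python sketch (not part of the patch) of how `Process()` above distributes (mIdx, nIdx) blocks of each group across cube cores, carrying the running block counter `count` from one group to the next. The group sizes and core count below are made up for illustration:

```python
def schedule(groups_mn, single_m, single_n, num_cores):
    """Mirror the per-group block distribution loop in KernelMoeGMMNoQuant::Process."""
    assignment = {core: [] for core in range(num_cores)}
    count = 0
    for g, (m, n) in enumerate(groups_mn):
        block_dim_m = (m + single_m - 1) // single_m   # Ceil()
        block_dim_n = (n + single_n - 1) // single_n
        cur_count = count + block_dim_m * block_dim_n
        for core in range(num_cores):
            cur_block = core if core >= count else core + num_cores
            while cur_block < cur_count:
                m_idx = (cur_block - count) // block_dim_n
                n_idx = (cur_block - count) % block_dim_n
                assignment[core].append((g, m_idx, n_idx))
                cur_block += num_cores
        count = cur_count % num_cores
    return assignment

# Example: two groups with 300 and 100 rows, n = 512, 128x256 blocks, 8 cube cores.
print(schedule([(300, 512), (100, 512)], 128, 256, 8))
```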
@@ -597,6 +597,38 @@ void transpose_kv_cache_by_block(

}

// Further improvements are expected after this operator is incorporated into CANN (planned for June 30th).
std::vector<at::Tensor> moe_grouped_matmul(
    at::Tensor x,
    at::Tensor weight,
    const at::Tensor& group_list,
    int64_t split_item,
    int64_t group_type,
    int64_t group_list_type
)
{
    bool transpose_weight = false;
    bool weight_nz = true;

    at::TensorList x_list = at::TensorList(x);
    at::TensorList weight_list = at::TensorList(weight);
    std::vector<at::Tensor> y;
    c10::TensorOptions options = x_list[0].options().dtype(x[0].scalar_type());
    auto m = x_list[0].sizes()[0];
    auto n = weight_list[0].sizes()[1];
    if (!transpose_weight) {
        n = weight_list[0].sizes()[2];
    }
    at::Tensor y_0 = at::empty(at::IntArrayRef{m, n}, options);
    y.emplace_back(y_0);
    at::TensorList result = at::TensorList(y);

    EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
                 x_list, weight_list, group_list, transpose_weight, result);

    return y;
}

} // namespace vllm_ascend

TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -779,4 +811,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
    );
    ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
    ops.def(
        "moe_grouped_matmul("
        "Tensor x,"
        "Tensor weight,"
        "Tensor group_list,"
        "int split_item,"
        "int group_type,"
        "int group_list_type)"

        "-> Tensor[]"
    );
    ops.impl("moe_grouped_matmul", torch::kPrivateUse1, &vllm_ascend::moe_grouped_matmul);
}
@@ -457,6 +457,35 @@ void transpose_kv_cache_by_block_meta(
{
    return;
}

std::vector<at::Tensor> moe_grouped_matmul_meta(
    at::Tensor x,
    at::Tensor weight,
    const at::Tensor& group_list,
    int64_t split_item,
    int64_t group_type,
    int64_t group_list_type
)
{
    bool transpose_weight = false;
    bool weight_nz = true;

    at::TensorList x_list = at::TensorList(x);
    at::TensorList weight_list = at::TensorList(weight);
    std::vector<at::Tensor> y;
    c10::TensorOptions options = x_list[0].options().dtype(x_list[0].scalar_type());
    auto m = x_list[0].sizes()[0];
    auto n = weight_list[0].sizes()[1];
    if (!transpose_weight) {
        n = weight_list[0].sizes()[2];
    }
    at::Tensor y_0 = at::zeros(at::IntArrayRef{m, n}, options);
    y.emplace_back(y_0);
    at::TensorList result = at::TensorList(y);

    return y;
}

} // namespace meta
} // namespace vllm_ascend

@@ -498,5 +527,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
    // transpose_kv_cache_by_block
    ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
    // moe_grouped_matmul
    ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
}
}
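A quick way to sanity-check the Meta registration above is to call the op on `meta` tensors, which exercises only shape/dtype propagation and never touches the NPU. The `_C_ascend` namespace and the argument values here are assumptions for illustration:

```python
import torch
# Assumes the vllm-ascend extension that registers _C_ascend has been imported.
x = torch.empty(32, 256, dtype=torch.bfloat16, device="meta")
w = torch.empty(4, 256, 512, dtype=torch.bfloat16, device="meta")
gl = torch.empty(4, 2, dtype=torch.int64, device="meta")
y = torch.ops._C_ascend.moe_grouped_matmul(x, w, gl, 0, 0, 1)
print(y[0].shape)  # expected torch.Size([32, 512]) from the meta implementation
```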
setup.py (6 lines changed)
@@ -405,8 +405,10 @@ class cmake_build_ext(build_ext):
         print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}")

     def run(self):
-        # First, ensure ACLNN custom-ops is built and installed.
-        self.run_command("build_aclnn")
+        if envs.COMPILE_CUSTOM_KERNELS:
+            # First, ensure ACLNN custom-ops is built and installed.
+            self.run_command("build_aclnn")

         # Then, run the standard build_ext command to compile the extensions
         super().run()