[Kernel] add custom op DispatchGmmCombineDecode (#4139)

#### What this PR does / why we need it?
Add the custom opapi DispatchGmmCombineDecode for A3, including the kernel
implementation, Python API, and pytest.

vLLM version: v0.11.0
vLLM main:
24d6314718


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
GuoRen868
2025-12-06 17:33:14 +08:00
committed by GitHub
parent cb42564942
commit 4bd1030842
29 changed files with 7851 additions and 27 deletions

View File

@@ -0,0 +1,59 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Build wiring for the DispatchGmmCombineDecode custom operator.
#
# Start with an empty list of extra include options; the catlass headers are appended below.
set(_DISPATCH_GMM_INC_OPTS)
# catlass is a hard requirement: configuration stops if its include directory is absent.
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
list(APPEND _DISPATCH_GMM_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
else()
message(FATAL_ERROR "dependency catlass is missing, you can fetch it by running 'git submodule update --init --recursive'")
endif()
# Per-operator compile options: the flags listed here plus the catlass include path collected above.
add_ops_compile_options(
OP_NAME DispatchGmmCombineDecode
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
${_DISPATCH_GMM_INC_OPTS}
)
# Operator definition source.
target_sources(op_host_aclnnInner PRIVATE
dispatch_gmm_combine_decode_def.cpp
)
# Public aclnn wrapper source.
target_sources(opapi PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
# When BUILD_OPEN_PROJECT is not set, the same wrapper source is also added to the train and infer targets.
if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
target_sources(aclnn_ops_infer PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
endif ()
# Tiling source; this directory is added to the optiling include path.
target_sources(optiling PRIVATE
dispatch_gmm_combine_decode_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
# Shape / data-type inference source.
target_sources(opsproto PRIVATE
dispatch_gmm_combine_decode_proto.cpp
)
# Install the public aclnn header. OPTIONAL means a missing file is not treated as an install error.
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")
install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,101 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <string.h>
#include "graph/types.h"
#include "aclnn/opdev/platform.h"
#include "aclnn_dispatch_gmm_combine_decode.h"
// Server-type selector passed as the second argument of NnopbaseSetHcclServerType.
// Values are spelled out explicitly; NNOPBASE_HCCL_SERVER_TYPE_END is the one-past-last marker.
enum NnopbaseHcclServerType {
    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
    NNOPBASE_HCCL_SERVER_TYPE_MTE = 1,
    NNOPBASE_HCCL_SERVER_TYPE_END = 2
};

// Declared weak: callers must test the symbol for null before calling it, because it may not be
// provided at link/load time.
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
#ifdef __cplusplus
extern "C" {
#endif
// Inner entry points for DispatchGmmCombineDecode. They are only declared in this file (their
// definitions live elsewhere); the public aclnn* wrappers below forward to them.
//
// Workspace-size query: takes the operator inputs, the attributes (groupEp ... globalBs), the two
// output tensors, and the workspaceSize / executor out-parameters.
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);
// Execution call: takes the workspace buffer and its size, the executor, and the stream.
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
/**
 * Public workspace-size query for DispatchGmmCombineDecode.
 *
 * Every argument is passed through, unchanged and in the same order, to
 * aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize; the status it returns is returned as-is.
 */
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(const aclTensor *x, const aclTensor *expertIds,
                                                          const aclTensor *gmm1PermutedWeight,
                                                          const aclTensor *gmm1PermutedWeightScale,
                                                          const aclTensor *gmm2Weight,
                                                          const aclTensor *gmm2WeightScale,
                                                          const aclTensor *expertSmoothScalesOptional,
                                                          const aclTensor *expertScalesOptional, char *groupEp,
                                                          int64_t epRankSize, int64_t epRankId,
                                                          int64_t moeExpertNum, int64_t shareExpertNum,
                                                          int64_t shareExpertRankNum, int64_t quantMode,
                                                          int64_t globalBs, const aclTensor *output,
                                                          const aclTensor *epRecvCount, uint64_t *workspaceSize,
                                                          aclOpExecutor **executor)
{
    const aclnnStatus innerStatus = aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
        x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
        expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum,
        shareExpertNum, shareExpertRankNum, quantMode, globalBs, output, epRecvCount, workspaceSize, executor);
    return innerStatus;
}
/**
 * Public execution entry point for DispatchGmmCombineDecode.
 *
 * If the weak NnopbaseSetHcclServerType symbol is available, the executor's HCCL server type is set
 * first: AICPU when the current SoC version is ASCEND910B, MTE otherwise. The call is then forwarded
 * to aclnnInnerDispatchGmmCombineDecode and its status is returned.
 */
aclnnStatus aclnnDispatchGmmCombineDecode(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
                                          aclrtStream stream)
{
    // The setter is a weak symbol, so it may be null; the platform is only queried when it exists.
    if (NnopbaseSetHcclServerType != nullptr) {
        const bool onAscend910B =
            (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B);
        const NnopbaseHcclServerType serverType =
            onAscend910B ? NNOPBASE_HCCL_SERVER_TYPE_AICPU : NNOPBASE_HCCL_SERVER_TYPE_MTE;
        NnopbaseSetHcclServerType(executor, serverType);
    }
    return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,51 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE
#define DISPATCH_GMM_COMBINE_DECODE
#include "aclnn/acl_meta.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * Workspace-size query for DispatchGmmCombineDecode (exported with default visibility).
 *
 * Parameters, in order: eight input tensors (x ... expertScalesOptional), the attributes
 * (groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode,
 * globalBs), the two output tensors (output, epRecvCount), and the out-parameters
 * workspaceSize and executor.
 *
 * NOTE(review): presumably this follows the usual aclnn two-phase convention (query here, then
 * launch with aclnnDispatchGmmCombineDecode) — confirm against the aclnn documentation.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);
/**
 * Execution entry point for DispatchGmmCombineDecode (exported with default visibility).
 *
 * Takes the workspace buffer and its size, the executor, and the stream to run on.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,83 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/op_def_registry.h"
namespace ops {
// Operator definition for DispatchGmmCombineDecode: declares 8 inputs (2 of them optional),
// 2 outputs and 8 attributes, binds the HCCL group attribute, and registers an AICore config.
//
// Every tensor lists two DataType/Format entries. Only "x" and "output" differ between the two
// positions (BF16, FLOAT16); all other tensors repeat the same type and format in both.
// NOTE(review): presumably position i across all tensors forms one supported combination —
// confirm against the OpDef documentation.
class DispatchGmmCombineDecode : public OpDef
{
public:
explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
{
// ---- Inputs (declaration order defines the input indices) ----
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
// Both GMM weights are INT8 in FRACTAL_NZ format; their scales are FLOAT in ND format.
this->Input("gmm1_permuted_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
// The two scale inputs below are the only OPTIONAL parameters.
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
// ---- Outputs ----
// "output" uses the same data-type list as "x".
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
// ---- Attributes (declaration order defines the attribute indices 0..7) ----
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();
this->Attr("moe_expert_num").Int();
this->Attr("share_expert_num").Int();
this->Attr("share_expert_rank_num").Int();
this->Attr("quant_mode").Int();
this->Attr("global_bs").Int();
// Bind the "group_ep" attribute as the operator's HCCL group.
this->MC2().HcclGroup({"group_ep"});
// The only AICore configuration registered here is "ascend910_93".
this->AICore().AddConfig("ascend910_93");
}
};
// Register the operator definition.
OP_ADD(DispatchGmmCombineDecode);
} // namespace ops

View File

@@ -0,0 +1,95 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdint>
#include "log/ops_log.h"
#include "error/ops_error.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
namespace ge {
// Input positions read during shape / data-type inference.
constexpr uint32_t EXPAND_X_INDEX{0};
constexpr uint32_t EXPERT_IDS_INDEX{1};

// Output positions.
constexpr uint32_t OUTPUT_X_INDEX{0};
constexpr uint32_t OUTPUT_REC_COUNT_INDEX{1};

// Attribute positions.
constexpr uint32_t ATTR_GROUP_EP_INDEX{0};
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX{1};
constexpr uint32_t ATTR_EP_RANK_ID_INDEX{2};
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX{3};
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX{4};
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX{5};
constexpr uint32_t ATTR_QUANT_MODE_INDEX{6};
constexpr uint32_t ATTR_GLOBAL_BS_INDEX{7};
/**
 * Infers the shapes of the two outputs of DispatchGmmCombineDecode.
 *
 * Output 0 gets the same rank as input x, with dim 0 taken from expert_ids dim 0 and dim 1 taken
 * from x dim 1; any further dims are copied from x.
 * Output 1 is 1-D. Its length is epRankSize when epRankId < sharedExpertRankNum, otherwise
 * epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)).
 *
 * @param context  inference context providing input/output shapes and attributes
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED when the context, a shape or a required attribute
 *         is missing, when x has fewer than 2 dims or expert_ids fewer than 1, or when the
 *         non-shared branch would divide by a non-positive rank count
 */
static ge::graphStatus InferShape(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    const char *nodeName = context->GetNodeName();
    // infer output shape
    const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
    const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
    gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
    gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
        recvCountOutShape == nullptr) {
        return GRAPH_FAILED;
    }
    if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
        return GRAPH_FAILED;
    }
    // Dims are kept as int64_t so that large values are not truncated by going through int.
    const int64_t bs = expertIdsShape->GetDim(0);
    const int64_t h = expandXShape->GetDim(1);
    const size_t xDimNum = expandXShape->GetDimNum();
    expandXOutShape->SetDimNum(xDimNum);
    expandXOutShape->SetDim(0, bs);
    expandXOutShape->SetDim(1, h);
    // When x has more than two dims, copy the trailing ones so no output dim is left unset.
    for (size_t dimIdx = 2; dimIdx < xDimNum; ++dimIdx) {
        expandXOutShape->SetDim(dimIdx, expandXShape->GetDim(dimIdx));
    }

    // infer recvCount shape
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
    OPS_ERR_IF(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankSizePtr == nullptr, OPS_LOG_E(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertRankNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    const uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
    const uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
    const uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
    const uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);

    recvCountOutShape->SetDimNum(1);
    const bool isShareExpert = (epRankId < sharedExpertRankNum);
    if (isShareExpert) {
        recvCountOutShape->SetDim(0, static_cast<int64_t>(epRankSize));
    } else {
        // The divisor below is (epRankSize - sharedExpertRankNum); reject values that would make it
        // zero (integer division by zero) or wrap around in unsigned arithmetic.
        OPS_ERR_IF(epRankSize <= sharedExpertRankNum,
                   OPS_LOG_E(nodeName, "epRankSize must be greater than sharedExpertRankNum."),
                   return ge::GRAPH_FAILED);
        const uint32_t moeRankNum = epRankSize - sharedExpertRankNum;
        const int64_t expertNumPerRank = static_cast<int64_t>(moeExpertNum / moeRankNum);
        // Multiply in 64 bits so the product cannot overflow a 32-bit intermediate.
        recvCountOutShape->SetDim(0, static_cast<int64_t>(epRankSize) * expertNumPerRank);
    }
    return GRAPH_SUCCESS;
}
/**
 * Infers the output data types: output 0 takes the data type of input x, and output 1 is
 * always DT_INT32. Always returns GRAPH_SUCCESS.
 */
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    // Output 0 mirrors the element type of x.
    context->SetOutputDataType(OUTPUT_X_INDEX, context->GetInputDataType(EXPAND_X_INDEX));
    // Output 1 is fixed to INT32.
    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
    return ge::GRAPH_SUCCESS;
}
// Register InferShape and InferDataType as the inference callbacks of DispatchGmmCombineDecode.
IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType);
} // namespace ge

View File

@@ -0,0 +1,335 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdio>
#include <cstdint>
#include <string>
#include "log/ops_log.h"
#include "error/ops_error.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/hccl/hccl_tiling.h"
using namespace ge;
namespace {
// Operation type handed to the MC2 tiling configuration.
constexpr uint32_t OP_TYPE_ALL_TO_ALL{8};

// Workspace sizing parameters (sizes in bytes unless noted otherwise).
constexpr uint32_t SYSTEM_NEED_WORKSPACE{16 * 1024 * 1024};
constexpr uint32_t GM_ALIGN_SIZE{512};
constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE{2};
constexpr uint32_t L1_TILE_BYTE_SIZE{32 * 1024};
constexpr uint32_t CUBE_WORKSPACE_STAGE{4};
constexpr uint32_t RESERVED_WORKSPACE_SIZE{256 * 1024};

// Positions of the operator inputs.
constexpr uint32_t INPUT_X_INDEX{0};
constexpr uint32_t INPUT_EXPERT_IDS_INDEX{1};
constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX{2};
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX{3};
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX{4};
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX{5};
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX{6};
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX{7};

// Positions of the operator attributes.
constexpr uint32_t ATTR_GROUP_EP_INDEX{0};
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX{1};
constexpr uint32_t ATTR_EP_RANK_ID_INDEX{2};
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX{3};
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX{4};
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX{5};
constexpr uint32_t ATTR_QUANT_MODE_INDEX{6};
constexpr uint32_t ATTR_GLOBAL_BS_INDEX{7};

// Limits enforced by the tiling-time validation.
constexpr uint32_t MIN_BATCH_SIZE{1};
constexpr uint32_t MAX_BATCH_SIZE{256};
constexpr uint32_t MAX_MOE_EXERT_NUM{512};
constexpr uint32_t SUPPORT_TOP_K{12};
constexpr uint32_t TWO_DIMS{2};
constexpr uint32_t MIN_TOKEN_LENGTH{512};
constexpr uint32_t MAX_TOKEN_LENGTH{7168};
constexpr uint32_t MIN_GMM1_HIDDEN{1024};
constexpr uint32_t MAX_GMM1_HIDDEN{6144};
}  // namespace
namespace optiling {
/**
 * Rounds x up to the nearest multiple of y.
 *
 * @param x  value to round
 * @param y  alignment granule (must be non-zero; it is used as a divisor)
 * @return the smallest multiple of y that is >= x
 */
static size_t CeilUp(size_t x, size_t y)
{
    const size_t granuleCount = (x + y - 1) / y;
    return granuleCount * y;
}
/**
 * Checks the storage shapes of the GMM weight and weight-scale inputs against the tiling data.
 *
 * The expected expert count for this rank is 1 when epRankId < sharedExpertRankNum, otherwise
 * moeExpertNumPerRank. Verified in order:
 *   - gmm1 weight:        dim 0 equals that expert count
 *   - gmm1 weight scale:  exactly 2 dims; dim 0 equals the expert count; dim 1 equals gmm1HLen
 *   - gmm2 weight:        dim 0 equals the expert count
 *   - gmm2 weight scale:  exactly 2 dims; dim 0 equals the expert count; dim 1 equals h
 *
 * NOTE(review): no call to this function is visible in this file — confirm whether it is meant to
 * be invoked from the tiling entry point.
 *
 * @return GRAPH_SUCCESS when all checks pass, GRAPH_FAILED (after logging) otherwise
 */
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
                                        DispatchGmmCombineDecodeTilingData &tilingData)
{
    const auto &info = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo;
    const uint32_t rankId = info.epRankId;
    const uint32_t sharedRankNum = info.sharedExpertRankNum;
    const uint32_t expertsPerMoeRank = info.moeExpertNumPerRank;
    const uint32_t h = info.h;
    const uint64_t gmm1WeightDim2 = info.gmm1HLen;

    // A shared-expert rank holds exactly one expert; every other rank holds moeExpertNumPerRank.
    uint32_t expectedExpertNum = expertsPerMoeRank;
    if (rankId < sharedRankNum) {
        expectedExpertNum = 1;
    }

    // ---- gmm1 weight ----
    const gert::StorageShape *gmm1Weight = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1Weight == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."), return ge::GRAPH_FAILED);
    const int64_t gmm1WeightExperts = gmm1Weight->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm1WeightExperts != expectedExpertNum,
               OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    // ---- gmm1 weight scale ----
    const gert::StorageShape *gmm1Scale = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm1Scale == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm1ScaleShape = gmm1Scale->GetStorageShape();
    OPS_ERR_IF(gmm1ScaleShape.GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm1ScaleShape.GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm1ScaleExperts = gmm1ScaleShape.GetDim(0);
    OPS_ERR_IF(gmm1ScaleExperts != expectedExpertNum,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm1ScaleWidth = gmm1ScaleShape.GetDim(1);
    OPS_ERR_IF(gmm1ScaleWidth != gmm1WeightDim2,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
               return ge::GRAPH_FAILED);

    // ---- gmm2 weight ----
    const gert::StorageShape *gmm2Weight = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
    OPS_ERR_IF(gmm2Weight == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."), return ge::GRAPH_FAILED);
    const int64_t gmm2WeightExperts = gmm2Weight->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm2WeightExperts != expectedExpertNum,
               OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    // ---- gmm2 weight scale ----
    const gert::StorageShape *gmm2Scale = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm2Scale == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm2ScaleShape = gmm2Scale->GetStorageShape();
    OPS_ERR_IF(gmm2ScaleShape.GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm2ScaleShape.GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm2ScaleExperts = gmm2ScaleShape.GetDim(0);
    OPS_ERR_IF(gmm2ScaleExperts != expectedExpertNum,
               OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm2ScaleWidth = gmm2ScaleShape.GetDim(1);
    OPS_ERR_IF(gmm2ScaleWidth != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
/**
 * Validates the scalar values already stored in the tiling data and fills in the default globalBs.
 *
 * Checks, in order:
 *   - bs in [MIN_BATCH_SIZE, MAX_BATCH_SIZE]
 *   - h in [MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH]
 *   - gmm1HLen in [MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN]
 *   - k <= SUPPORT_TOP_K
 *   - globalBs: when 0 it is replaced by epRankSize * bs (written back to tilingData); otherwise it
 *     must be divisible by epRankSize
 *   - moeExpertNumPerRank <= aivNum / 2
 *
 * @param nodeName    node name used in log messages
 * @param tilingData  tiling data to validate; globalBs may be updated
 * @return GRAPH_SUCCESS when all checks pass, GRAPH_FAILED (after logging) otherwise
 */
static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
{
    uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
    // The limits are unsigned, so they are formatted with %u.
    OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must >= %u.", MIN_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must <= %u.", MAX_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
    OPS_ERR_IF(
        tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
        OPS_LOG_E(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
        return ge::GRAPH_FAILED);
    uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
    OPS_ERR_IF(
        gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
        OPS_LOG_E(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
        return ge::GRAPH_FAILED);
    uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
    OPS_ERR_IF(topK > SUPPORT_TOP_K, OPS_LOG_E(nodeName, "topK(k) must <= %u.", SUPPORT_TOP_K),
               return ge::GRAPH_FAILED);

    uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
    uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
    if (globalBatchSize == 0) {
        // globalBs == 0 means "not specified": derive it from the local batch size.
        globalBatchSize = epRankSize * batchSize;
        tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize;
    } else {
        // globalBatchSize is unsigned here, so a "< 0" test can never fire and is not repeated;
        // sign validation has to happen on the raw attribute before it is narrowed.
        // Guard the modulo below against a zero divisor.
        OPS_ERR_IF(epRankSize == 0, OPS_LOG_E(nodeName, "epRankSize must > 0."), return ge::GRAPH_FAILED);
        OPS_ERR_IF(globalBatchSize % epRankSize > 0,
                   OPS_LOG_E(nodeName, "globalBatchSize must be divisible by epRankSize."),
                   return ge::GRAPH_FAILED);
    }

    uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t recvAivNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aivNum / 2;
    OPS_ERR_IF(
        moeExpertNumPerRank > recvAivNum,
        OPS_LOG_E(nodeName, "moeExpertNumPerRank must <= (aivNum/2)(%u), but got %u", recvAivNum, moeExpertNumPerRank),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the operator attributes, validates them, and stores them in the tiling data.
 *
 * All attribute pointers are null-checked before use. Range checks run on the raw int64_t attribute
 * values, before they are narrowed to uint32_t, so negative or oversized values are rejected instead
 * of silently wrapping. The divisor (epRankSize - sharedExpertRankNum) is proven positive before
 * moeExpertNumPerRank is computed.
 *
 * @param context     tiling context providing the attributes
 * @param nodeName    node name used in log messages
 * @param tilingData  receives epRankSize, epRankId, moeExpertNum, sharedExpertNum,
 *                    sharedExpertRankNum, quantMode, globalBs and moeExpertNumPerRank
 * @param groupEp     receives the group_ep attribute string
 * @return GRAPH_SUCCESS on success, GRAPH_FAILED (after logging) on any missing or invalid attribute
 */
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
                                               DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
{
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
    auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
    auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);

    // Every pointer is dereferenced below, so each one must be present.
    OPS_ERR_IF(groupEpPtr == nullptr, OPS_LOG_E(nodeName, "groupEpPtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankSizePtr == nullptr, OPS_LOG_E(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertRankNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(quantModePtr == nullptr, OPS_LOG_E(nodeName, "quantModePtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBsPtr is nullptr."), return ge::GRAPH_FAILED);

    // Validate in the signed 64-bit domain first; narrowing to uint32_t happens only afterwards.
    const int64_t epRankSizeAttr = *epRankSizePtr;
    const int64_t epRankIdAttr = *epRankIdPtr;
    const int64_t moeExpertNumAttr = *moeExpertNumPtr;
    const int64_t sharedExpertNumAttr = *sharedExpertNumPtr;
    const int64_t sharedExpertRankNumAttr = *sharedExpertRankNumPtr;
    const int64_t globalBsAttr = *globalBsPtr;
    constexpr int64_t uint32Limit = static_cast<int64_t>(UINT32_MAX);

    OPS_ERR_IF(epRankSizeAttr <= 0 || epRankSizeAttr > uint32Limit,
               OPS_LOG_E(nodeName, "epRankSize must be > 0 and fit in 32 bits."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankIdAttr < 0, OPS_LOG_E(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankIdAttr >= epRankSizeAttr, OPS_LOG_E(nodeName, "epRankId must < epRankSize."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumAttr > static_cast<int64_t>(MAX_MOE_EXERT_NUM),
               OPS_LOG_E(nodeName, "moeExpertNum must <= %u.", MAX_MOE_EXERT_NUM), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumAttr <= 0, OPS_LOG_E(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertNumAttr != 1, OPS_LOG_E(nodeName, "sharedExpertNum must be 1."),
               return ge::GRAPH_FAILED);
    // Guarantees (epRankSize - sharedExpertRankNum) > 0, so the division below cannot be by zero.
    OPS_ERR_IF(sharedExpertRankNumAttr < 0 || sharedExpertRankNumAttr >= epRankSizeAttr,
               OPS_LOG_E(nodeName, "sharedExpertRankNum must be >= 0 and < epRankSize."),
               return ge::GRAPH_FAILED);
    const int64_t moeRankNum = epRankSizeAttr - sharedExpertRankNumAttr;
    OPS_ERR_IF(moeExpertNumAttr % moeRankNum != 0,
               OPS_LOG_E(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(globalBsAttr < 0 || globalBsAttr > uint32Limit,
               OPS_LOG_E(nodeName, "globalBatchSize must >= 0 and fit in 32 bits."), return ge::GRAPH_FAILED);

    groupEp = std::string(groupEpPtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = static_cast<uint32_t>(epRankSizeAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = static_cast<uint32_t>(epRankIdAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = static_cast<uint32_t>(moeExpertNumAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = static_cast<uint32_t>(sharedExpertNumAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum =
        static_cast<uint32_t>(sharedExpertRankNumAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast<uint32_t>(globalBsAttr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank =
        static_cast<uint32_t>(moeExpertNumAttr / moeRankNum);
    return ge::GRAPH_SUCCESS;
}
/**
 * Fills the MC2 communication tiling sections of the tiling data.
 *
 * Builds an AscendC::Mc2CcTilingConfig from the group name, the all-to-all operation type and the
 * all-to-all algorithm string, then calls GetTiling for tiling->mc2InitTiling and tiling->mc2CcTiling.
 * Any values returned by those GetTiling calls are not inspected.
 *
 * @param context  tiling context (used only for the node name in the debug log)
 * @param tiling   tiling data whose mc2InitTiling / mc2CcTiling members are filled
 * @param groupEp  communication group name; taken by const reference to avoid a string copy
 */
static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling,
                        const std::string &groupEp)
{
    const char *nodeName = context->GetNodeName();
    OPS_LOG_D(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str());
    uint32_t opType = OP_TYPE_ALL_TO_ALL;
    std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
    AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr);
    mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
    mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
}
/**
 * Computes the workspace size for the kernel and stores it in workspace slot 0.
 *
 * The maximum token count is (globalBs / epRankSize) * epRankSize / sharedExpertRankNum when
 * epRankId < sharedExpertRankNum, otherwise (globalBs / epRankSize) * epRankSize * min(k, moeExpertNumPerRank).
 * Each buffer size is rounded up to GM_ALIGN_SIZE; the total is their sum plus SYSTEM_NEED_WORKSPACE.
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when the workspace-size array is unavailable
 */
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName,
                                    DispatchGmmCombineDecodeTilingData &tilingData)
{
    size_t *workSpaces = context->GetWorkspaceSizes(1);
    OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);

    const auto &info = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo;
    const uint32_t rankCount = info.epRankSize;
    const uint32_t rankId = info.epRankId;
    const uint32_t sharedRankNum = info.sharedExpertRankNum;
    const uint32_t localBs = info.bs;
    const uint32_t globalBs = info.globalBs;
    const uint32_t perRankBs = globalBs / rankCount;
    const uint32_t topK = info.k;
    const uint32_t expertsPerRank = info.moeExpertNumPerRank;
    const uint32_t h = info.h;
    const uint32_t aicNum = info.aicNum;
    const uint64_t gmm2HLen = info.gmm1HLen / 2;

    // Upper bound on the number of tokens this rank can receive.
    const size_t maxTokenNum = (rankId < sharedRankNum)
                                   ? (perRankBs * rankCount / sharedRankNum)
                                   : (perRankBs * rankCount * std::min(topK, expertsPerRank));

    // Accumulate the aligned size of each user buffer.
    size_t userBytes = 0;
    userBytes += CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);  // x2 tokens
    userBytes += CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);              // x2 scales
    userBytes += CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t),
                        GM_ALIGN_SIZE);                                           // CV swap buffer
    userBytes += CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);   // swiglu output
    userBytes += CeilUp(expertsPerRank * sizeof(int64_t), GM_ALIGN_SIZE);         // group list
    userBytes += CeilUp(localBs * topK * sizeof(int32_t), GM_ALIGN_SIZE);         // expand index
    userBytes += CeilUp(rankCount * expertsPerRank * sizeof(int32_t), GM_ALIGN_SIZE);  // ep send count
    userBytes += CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);         // x1 tokens
    userBytes += CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);              // x1 scales
    userBytes += CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);  // gmm2 output
    userBytes += CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);                  // reserved area

    workSpaces[0] = SYSTEM_NEED_WORKSPACE + userBytes;
    return ge::GRAPH_SUCCESS;
}
/**
 * Tiling implementation for DispatchGmmCombineDecode.
 *
 * Steps, in order:
 *   1. Take bs and h from the 2-D storage shape of x, and k from dim 1 of the 2-D expert_ids shape.
 *   2. Read and store the attributes (GetAttrAndSetTilingData).
 *   3. Take gmm1HLen from dim 2 of the gmm1 weight's origin shape.
 *   4. Record the AIC / AIV core counts from the platform info.
 *   5. Validate the collected values (CheckData) and set the workspace size (SetWorkSpace).
 *   6. Fill the communication tiling (SetHcommCfg).
 *   7. Select the tiling key: 0 when moeExpertNumPerRank == 1, EXEC_FLAG_DEEP_FUSE otherwise;
 *      the block dim is set to the AIC core count.
 *
 * @return GRAPH_SUCCESS on success, GRAPH_FAILED (after logging) when any step fails
 */
static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context)
{
    const char *nodeName = context->GetNodeName();
    DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData<DispatchGmmCombineDecodeTilingData>();
    OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
    auto &info = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo;
    std::string groupEp = "";

    // x must be 2-D: (bs, h).
    const gert::StorageShape *xStorage = context->GetInputShape(INPUT_X_INDEX);
    OPS_ERR_IF(xStorage == nullptr, OPS_LOG_E(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
    const auto &xShape = xStorage->GetStorageShape();
    OPS_ERR_IF(xShape.GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "x shape dims must be 2, but current dim num is %lu.", xShape.GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t batchSize = xShape.GetDim(0);
    info.bs = batchSize;
    const int64_t hiddenSize = xShape.GetDim(1);
    info.h = hiddenSize;

    // expert_ids must be 2-D; its second dim is top-k.
    const gert::StorageShape *expertIdsStorage = context->GetInputShape(INPUT_EXPERT_IDS_INDEX);
    OPS_ERR_IF(expertIdsStorage == nullptr, OPS_LOG_E(nodeName, "expertIds shape is null."),
               return ge::GRAPH_FAILED);
    const auto &expertIdsShape = expertIdsStorage->GetStorageShape();
    OPS_ERR_IF(expertIdsShape.GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.",
                         expertIdsShape.GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t topK = expertIdsShape.GetDim(1);
    info.k = topK;

    const ge::graphStatus attrStatus = GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp);
    OPS_ERR_IF(attrStatus != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
               return ge::GRAPH_FAILED);

    // gmm1HLen comes from the origin (not storage) shape of the gmm1 weight, dim index 2.
    const gert::StorageShape *gmm1WeightStorage = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1WeightStorage == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
               return ge::GRAPH_FAILED);
    info.gmm1HLen = gmm1WeightStorage->GetOriginShape().GetDim(TWO_DIMS);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    const uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    const uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    info.aicNum = aicNum;
    info.aivNum = aivNum;

    const ge::graphStatus dataStatus = CheckData(nodeName, *tilingData);
    OPS_ERR_IF(dataStatus != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED);
    const ge::graphStatus workspaceStatus = SetWorkSpace(context, nodeName, *tilingData);
    OPS_ERR_IF(workspaceStatus != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "Tiling set workspace failed."),
               return ge::GRAPH_FAILED);
    SetHcommCfg(context, tilingData, groupEp);

    // More than one expert per rank selects the deep-fuse tiling key.
    if (info.moeExpertNumPerRank != 1) {
        context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
    } else {
        context->SetTilingKey(0);
    }
    context->SetBlockDim(aicNum);
    return ge::GRAPH_SUCCESS;
}
/**
 * Registered tiling entry point; delegates to DispatchGmmCombineDecodeTilingFuncImpl and returns
 * its status unchanged.
 */
static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context)
{
    return DispatchGmmCombineDecodeTilingFuncImpl(context);
}
// Empty compile-info type; used only as the template argument of TilingParse in the registration below.
struct DispatchGmmCombineDecodeCompileInfo {};
/**
 * Tiling-parse callback for DispatchGmmCombineDecode. It performs no work: the context is ignored
 * and GRAPH_SUCCESS is always returned.
 */
ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context)
{
    static_cast<void>(context);  // intentionally unused
    return ge::GRAPH_SUCCESS;
}
// Register the tiling function and the tiling-parse callback (with its compile-info type)
// for DispatchGmmCombineDecode.
IMPL_OP_OPTILING(DispatchGmmCombineDecode)
.Tiling(DispatchGmmCombineDecodeTilingFunc)
.TilingParse<DispatchGmmCombineDecodeCompileInfo>(TilingParseForDispatchGmmCombineDecode);
} // namespace optiling