[Kernel] add custom op DispatchGmmCombineDecode (#4139)
#### What this PR does / why we need it? Add the custom opapi DispatchGmmCombineDecode for A3, including the kernel implementation, the Python API, and pytest tests. vLLM version: v0.11.0 - vLLM main: 24d6314718 - vLLM version: v0.12.0 - vLLM main: ad32e3e19c. Signed-off-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
59
csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt
Normal file
59
csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Include options for the catlass dependency (required; vendored as a git submodule).
set(_DISPATCH_GMM_INC_OPTS)
if(EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
    list(APPEND _DISPATCH_GMM_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
else()
    message(FATAL_ERROR "dependency catlass is missing, you can fetch it by running 'git submodule update --init --recursive'")
endif()

# Kernel compile options for the DispatchGmmCombineDecode custom op.
add_ops_compile_options(
    OP_NAME DispatchGmmCombineDecode
    OPTIONS --cce-auto-sync=off
    -Wno-deprecated-declarations
    -Werror
    ${_DISPATCH_GMM_INC_OPTS}
)

# Op definition registration (inner aclnn host library).
target_sources(op_host_aclnnInner PRIVATE
    dispatch_gmm_combine_decode_def.cpp
)

# Public aclnn API wrapper.
target_sources(opapi PRIVATE
    aclnn_dispatch_gmm_combine_decode.cpp
)

if(NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        aclnn_dispatch_gmm_combine_decode.cpp
    )

    target_sources(aclnn_ops_infer PRIVATE
        aclnn_dispatch_gmm_combine_decode.cpp
    )
endif()

# Host tiling implementation.
target_sources(optiling PRIVATE
    dispatch_gmm_combine_decode_tiling.cpp
)

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# Shape/dtype inference registration.
target_sources(opsproto PRIVATE
    dispatch_gmm_combine_decode_proto.cpp
)

# Install the public aclnn header directly. The original used file(GLOB) on a
# single, known filename, which needlessly hides the file from the build system;
# OPTIONAL preserves the non-fatal behavior when the header is absent.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
||||
@@ -0,0 +1,101 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn/opdev/platform.h"
|
||||
#include "aclnn_dispatch_gmm_combine_decode.h"
|
||||
|
||||
// Selects which engine services HCCL communication requests for the op.
// Values are passed to NnopbaseSetHcclServerType; do not reorder.
enum NnopbaseHcclServerType {
    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, // communication serviced by AI CPU
    NNOPBASE_HCCL_SERVER_TYPE_MTE,       // communication serviced by MTE
    NNOPBASE_HCCL_SERVER_TYPE_END        // sentinel, not a valid server type
};
// Weak symbol: only present in runtimes that support selecting the HCCL server
// type. Callers must null-check it before invoking (see aclnnDispatchGmmCombineDecode).
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Forward declaration of the inner two-phase entry point (implemented elsewhere,
// presumably generated alongside the op definition — not visible in this file).
// Computes the required workspace size and builds the op executor.
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
    const aclTensor *x,                              // token activations
    const aclTensor *expertIds,                      // per-token expert ids
    const aclTensor *gmm1PermutedWeight,             // first grouped-matmul weight
    const aclTensor *gmm1PermutedWeightScale,        // first grouped-matmul dequant scale
    const aclTensor *gmm2Weight,                     // second grouped-matmul weight
    const aclTensor *gmm2WeightScale,                // second grouped-matmul dequant scale
    const aclTensor *expertSmoothScalesOptional,     // optional smooth scales
    const aclTensor *expertScalesOptional,           // optional expert scales
    char *groupEp,                                   // EP communication group name
    int64_t epRankSize,
    int64_t epRankId,
    int64_t moeExpertNum,
    int64_t shareExpertNum,
    int64_t shareExpertRankNum,
    int64_t quantMode,
    int64_t globalBs,
    const aclTensor *output,                         // combined output tokens
    const aclTensor *epRecvCount,                    // per-rank receive counts
    uint64_t *workspaceSize,                         // out: required workspace bytes
    aclOpExecutor **executor);                       // out: op executor
// Forward declaration of the inner launch entry point.
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream);
||||
|
||||
/**
 * Public aclnn API, phase 1: computes the workspace size required by
 * DispatchGmmCombineDecode and creates the op executor.
 *
 * This is a thin pass-through to the inner implementation; every argument is
 * forwarded unchanged, in declaration order.
 */
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
    const aclTensor *x,
    const aclTensor *expertIds,
    const aclTensor *gmm1PermutedWeight,
    const aclTensor *gmm1PermutedWeightScale,
    const aclTensor *gmm2Weight,
    const aclTensor *gmm2WeightScale,
    const aclTensor *expertSmoothScalesOptional,
    const aclTensor *expertScalesOptional,
    char *groupEp,
    int64_t epRankSize,
    int64_t epRankId,
    int64_t moeExpertNum,
    int64_t shareExpertNum,
    int64_t shareExpertRankNum,
    int64_t quantMode,
    int64_t globalBs,
    const aclTensor *output,
    const aclTensor *epRecvCount,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
        x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale,
        expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum,
        shareExpertNum, shareExpertRankNum, quantMode, globalBs, output, epRecvCount,
        workspaceSize, executor);
}
||||
|
||||
/**
 * Public aclnn API, phase 2: launches DispatchGmmCombineDecode on the stream.
 *
 * When the runtime provides the weak symbol NnopbaseSetHcclServerType, the HCCL
 * server type is selected first: AI CPU on ASCEND910B SoCs, MTE otherwise.
 */
aclnnStatus aclnnDispatchGmmCombineDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    // Weak symbol — may be absent in older runtimes; guard before calling.
    if (NnopbaseSetHcclServerType != nullptr) {
        const bool isAscend910B =
            op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B;
        NnopbaseSetHcclServerType(
            executor, isAscend910B ? NNOPBASE_HCCL_SERVER_TYPE_AICPU : NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
// NOTE(review): guard macro lacks the conventional _H suffix — confirm project style.
#ifndef DISPATCH_GMM_COMBINE_DECODE
#define DISPATCH_GMM_COMBINE_DECODE

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Phase 1 of the two-phase aclnn calling convention: computes the workspace
 * size required by DispatchGmmCombineDecode and creates the op executor.
 * Call this before aclnnDispatchGmmCombineDecode.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
    const aclTensor *x,                          // token activations
    const aclTensor *expertIds,                  // per-token expert ids
    const aclTensor *gmm1PermutedWeight,         // first grouped-matmul weight
    const aclTensor *gmm1PermutedWeightScale,    // first grouped-matmul dequant scale
    const aclTensor *gmm2Weight,                 // second grouped-matmul weight
    const aclTensor *gmm2WeightScale,            // second grouped-matmul dequant scale
    const aclTensor *expertSmoothScalesOptional, // optional smooth scales (may be null)
    const aclTensor *expertScalesOptional,       // optional expert scales (may be null)
    char *groupEp,                               // EP communication group name
    int64_t epRankSize,
    int64_t epRankId,
    int64_t moeExpertNum,
    int64_t shareExpertNum,
    int64_t shareExpertRankNum,
    int64_t quantMode,
    int64_t globalBs,
    const aclTensor *output,                     // combined output tokens
    const aclTensor *epRecvCount,                // per-rank receive counts
    uint64_t *workspaceSize,                     // out: required workspace bytes
    aclOpExecutor **executor);                   // out: op executor

/**
 * Phase 2: launches the op on the given stream using the workspace and
 * executor produced by the GetWorkspaceSize call above.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
|
||||
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
// Graph-engine registration for the DispatchGmmCombineDecode custom op:
// declares inputs/outputs (with per-variant dtypes and formats) and attributes.
// NOTE: registration order fixes the input/attr indices used by the proto and
// tiling files — do not reorder.
class DispatchGmmCombineDecode : public OpDef
{
public:
    explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
    {
        // Token activations; bf16 or fp16 variants.
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-token expert ids.
        this->Input("expert_ids")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // First grouped-matmul weight, int8, stored in FRACTAL_NZ layout.
        this->Input("gmm1_permuted_weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // Dequant scale for gmm1.
        this->Input("gmm1_permuted_weight_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Second grouped-matmul weight, int8, FRACTAL_NZ layout.
        this->Input("gmm2_weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // Dequant scale for gmm2.
        this->Input("gmm2_weight_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Optional per-expert smooth scales.
        this->Input("expert_smooth_scales")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Optional per-expert scales.
        this->Input("expert_scales")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Combined output tokens; dtype follows x.
        this->Output("output")
            .ParamType(REQUIRED)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-rank receive counts for the EP group.
        this->Output("ep_recv_count")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Attributes; registration order fixes the ATTR_*_INDEX constants.
        this->Attr("group_ep").String();
        this->Attr("ep_rank_size").Int();
        this->Attr("ep_rank_id").Int();
        this->Attr("moe_expert_num").Int();
        this->Attr("share_expert_num").Int();
        this->Attr("share_expert_rank_num").Int();
        this->Attr("quant_mode").Int();
        this->Attr("global_bs").Int();

        // MC2 (compute-communication fusion) uses "group_ep" as the HCCL group.
        this->MC2().HcclGroup({"group_ep"});
        // Registered only for the ascend910_93 (A3) AI Core.
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(DispatchGmmCombineDecode);
} // namespace ops
|
||||
@@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ge {
|
||||
// Input/output indices — must match the Input/Output registration order in the
// DispatchGmmCombineDecode OpDef.
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;

// Attribute indices — must match the Attr registration order in the OpDef.
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
|
||||
|
||||
// Infers output shapes for DispatchGmmCombineDecode.
//   output:        dim0 = expert_ids dim0 (bs), dim1 = x dim1 (h)
//   ep_recv_count: 1-D; length depends on whether this rank hosts a shared expert.
// Returns GRAPH_FAILED on missing shapes/attrs or inconsistent rank attributes.
static ge::graphStatus InferShape(gert::InferShapeContext *context)
{
    const char *nodeName = context->GetNodeName();
    // infer output shape
    const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
    const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
    gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
    gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
        recvCountOutShape == nullptr) {
        return GRAPH_FAILED;
    }
    if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
        return GRAPH_FAILED;
    }

    int bs = expertIdsShape->GetDim(0);
    int h = expandXShape->GetDim(1);

    expandXOutShape->SetDimNum(expandXShape->GetDimNum());
    expandXOutShape->SetDim(0, bs);
    expandXOutShape->SetDim(1, h);

    // infer recvCount shape from op attributes
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);

    OPS_ERR_IF(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankSizePtr == nullptr, OPS_LOG_E(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertRankNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);

    recvCountOutShape->SetDimNum(1);
    // Ranks [0, sharedExpertRankNum) host the shared expert; the rest host MoE experts.
    bool isShareExpert = (epRankId < sharedExpertRankNum);
    if (isShareExpert) {
        recvCountOutShape->SetDim(0, epRankSize);
    } else {
        // BUGFIX: guard the divisor — the original divided by
        // (epRankSize - sharedExpertRankNum) unconditionally, which is a
        // division by zero when epRankSize == sharedExpertRankNum.
        OPS_ERR_IF(epRankSize <= sharedExpertRankNum,
                   OPS_LOG_E(nodeName, "epRankSize must be greater than sharedExpertRankNum."),
                   return ge::GRAPH_FAILED);
        recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
    }

    return GRAPH_SUCCESS;
}
|
||||
|
||||
// Output dtype inference: 'output' mirrors the dtype of input x,
// 'ep_recv_count' is always int32.
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    context->SetOutputDataType(OUTPUT_X_INDEX, context->GetInputDataType(EXPAND_X_INDEX));
    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType);
|
||||
} // namespace ge
|
||||
@@ -0,0 +1,335 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace {
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;                     // HCCL op type id for AllToAll
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;   // fixed system workspace (bytes)
constexpr uint32_t GM_ALIGN_SIZE = 512;                        // GM buffer alignment (bytes)
constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2;                  // bf16/fp16 token element size
constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024;              // per-tile L1 staging size (bytes)
constexpr uint32_t CUBE_WORKSPACE_STAGE = 4;                   // pipeline stages for cube workspace
constexpr uint32_t RESERVED_WORKSPACE_SIZE = 256 * 1024;       // extra reserved workspace (bytes)

// Input indices — must match the OpDef Input registration order.
constexpr uint32_t INPUT_X_INDEX = 0;
constexpr uint32_t INPUT_EXPERT_IDS_INDEX = 1;
constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;

// Attribute indices — must match the OpDef Attr registration order.
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;

// Supported value ranges enforced by CheckData/CheckTensorShape.
constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;    // NOTE(review): "EXERT" looks like a typo for "EXPERT"
constexpr uint32_t SUPPORT_TOP_K = 12;
constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512;     // min hidden size h
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;    // max hidden size h
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
// Rounds x up to the nearest multiple of y. y must be non-zero.
static size_t CeilUp(size_t x, size_t y)
{
    const size_t remainder = x % y;
    return (remainder == 0) ? x : x + (y - remainder);
}
||||
|
||||
// Validates the storage shapes of the gmm1/gmm2 weights and their dequant
// scales against the tiling data already derived from the attributes:
//   * dim0 of every weight/scale tensor must equal the number of experts hosted
//     on this rank (1 for a shared-expert rank, moeExpertNumPerRank otherwise);
//   * the scale tensors must be 2-D, with dim1 matching gmm1HLen (gmm1) or h (gmm2).
// Returns GRAPH_FAILED with a log message on any violation.
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
                                        DispatchGmmCombineDecodeTilingData &tilingData)
{
    uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
    uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
    uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
    uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;

    // A shared-expert rank hosts exactly one expert; a MoE rank hosts moeExpertNumPerRank.
    uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
    const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    // gmm1 weight scale: 2-D, (localExpertNum, gmm1HLen).
    const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
    OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
               return ge::GRAPH_FAILED);

    // gmm2 weight: dim0 must be the local expert count.
    const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
    OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    // gmm2 weight scale: 2-D, (localExpertNum, h).
    const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
    OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Validates the values already stored in the tiling data (batch size, hidden
// sizes, top-k, expert distribution) and resolves the default global batch
// size: globalBs == 0 means "epRankSize * bs".
// Returns GRAPH_FAILED with a log message on any violation.
static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
{
    uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
    // BUGFIX: use %u for uint32_t values (the original passed them to %d).
    OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must >= %u.", MIN_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must <= %u.", MAX_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
    OPS_ERR_IF(
        tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
        OPS_LOG_E(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
        return ge::GRAPH_FAILED);
    uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
    OPS_ERR_IF(
        gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
        OPS_LOG_E(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
        return ge::GRAPH_FAILED);
    uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
    OPS_ERR_IF(topK > SUPPORT_TOP_K, OPS_LOG_E(nodeName, "topK(k) must <= %u.", SUPPORT_TOP_K),
               return ge::GRAPH_FAILED);
    uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
    uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
    // Guard against a zero divisor/multiplier before using epRankSize below.
    OPS_ERR_IF(epRankSize == 0, OPS_LOG_E(nodeName, "epRankSize must > 0."), return ge::GRAPH_FAILED);
    if (globalBatchSize == 0) {
        // 0 means "use the default": every EP rank contributes its local batch.
        globalBatchSize = epRankSize * batchSize;
        tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize;
    } else {
        // BUGFIX: the original also tested `globalBatchSize < 0`, which is
        // always false for an unsigned value; the dead check was removed.
        OPS_ERR_IF(globalBatchSize % epRankSize > 0,
                   OPS_LOG_E(nodeName, "globalBatchSize must be divisible by epRankSize."),
                   return ge::GRAPH_FAILED);
    }
    uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    // Half of the vector cores service the receive path.
    uint32_t recvAivNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aivNum / 2;
    OPS_ERR_IF(
        moeExpertNumPerRank > recvAivNum,
        OPS_LOG_E(nodeName, "moeExpertNumPerRank must <= (aivNum/2)(%u), but got %u", recvAivNum, moeExpertNumPerRank),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Reads the op attributes, validates them, and populates the communication
// section of the tiling data. On success, groupEp receives the EP
// communication group name.
// Returns GRAPH_FAILED with a log message on missing or invalid attributes.
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
                                               DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
{
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
    auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
    auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);

    // BUGFIX: the original dereferenced all of these without null checks.
    OPS_ERR_IF(groupEpPtr == nullptr || epRankSizePtr == nullptr || epRankIdPtr == nullptr ||
                   moeExpertNumPtr == nullptr || sharedExpertNumPtr == nullptr ||
                   sharedExpertRankNumPtr == nullptr || quantModePtr == nullptr || globalBsPtr == nullptr,
               OPS_LOG_E(nodeName, "one of the op attributes is nullptr."), return ge::GRAPH_FAILED);

    // BUGFIX: validate the signed attribute BEFORE the unsigned cast. The
    // original checked `epRankId < 0` on a uint32_t, which can never be true.
    OPS_ERR_IF(*epRankIdPtr < 0, OPS_LOG_E(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED);

    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
    uint32_t sharedExpertNum = static_cast<uint32_t>(*sharedExpertNumPtr);
    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);

    OPS_ERR_IF(epRankId >= epRankSize, OPS_LOG_E(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum > MAX_MOE_EXERT_NUM, OPS_LOG_E(nodeName, "moeExpertNum must <= %u.", MAX_MOE_EXERT_NUM),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum <= 0, OPS_LOG_E(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertNum != 1, OPS_LOG_E(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED);
    // BUGFIX: validate the divisor BEFORE dividing. The original computed
    // moeExpertNum / (epRankSize - sharedExpertRankNum) first, dividing by
    // zero when epRankSize == sharedExpertRankNum.
    OPS_ERR_IF(epRankSize <= sharedExpertRankNum,
               OPS_LOG_E(nodeName, "epRankSize must be greater than sharedExpertRankNum."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0,
               OPS_LOG_E(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."),
               return ge::GRAPH_FAILED);
    uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum);

    groupEp = std::string(groupEpPtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = epRankSize;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = epRankId;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = moeExpertNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = sharedExpertNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum = sharedExpertRankNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank = moeExpertNumPerRank;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Configures the HCCL (MC2) tiling for the EP communication group: builds the
// AllToAll tiling config and writes both the init and per-communication
// tiling blobs into the tiling data.
static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling, const std::string groupEp)
{
    const char *nodeName = context->GetNodeName();
    OPS_LOG_D(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str());
    uint32_t opType = OP_TYPE_ALL_TO_ALL;
    // AllToAll algorithm selection: full-mesh within a node, pairwise across nodes.
    std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
    // NOTE: the original also built an "AllGather=level0:ring" config string
    // that was never used; the dead local has been removed.

    AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr);
    mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
    mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
}
|
||||
|
||||
// Compute the total device workspace (system portion plus the per-buffer user
// portion) and publish it through the tiling context.
//
// context    : tiling context; its workspace-size array slot 0 is written.
// nodeName   : node name used for error logging.
// tilingData : tiling data previously populated with shapes and attributes.
// Returns GRAPH_SUCCESS, or GRAPH_FAILED if the workspace array is missing.
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName,
                                    DispatchGmmCombineDecodeTilingData &tilingData)
{
    size_t *workSpaces = context->GetWorkspaceSizes(1);
    OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);

    const auto &comInfo = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo;
    const uint32_t epRankSize = comInfo.epRankSize;
    const uint32_t epRankId = comInfo.epRankId;
    const uint32_t sharedExpertRankNum = comInfo.sharedExpertRankNum;
    const uint32_t batchSize = comInfo.bs;
    const uint32_t globalBs = comInfo.globalBs;
    const uint32_t maxBatchSize = globalBs / epRankSize;
    const uint32_t topK = comInfo.k;
    const uint32_t moeExpertNumPerRank = comInfo.moeExpertNumPerRank;
    const uint32_t hiddenLen = comInfo.h;
    const uint32_t aicNum = comInfo.aicNum;
    // gmm2's hidden length is half of gmm1's output length (swiglu halves it).
    const uint64_t gmm2HLen = comInfo.gmm1HLen / 2;

    // Shared-expert ranks take an even share of all tokens; MoE ranks may
    // receive up to min(topK, expertsPerRank) duplicates of every token.
    size_t maxTokenNum;
    if (epRankId < sharedExpertRankNum) {
        maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
    } else {
        maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
    }

    // Every intermediate buffer, each rounded up to the GM alignment.
    const size_t bufferSizes[] = {
        CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE),   // x2 tokens
        CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE),               // x2 scales
        CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t),
               GM_ALIGN_SIZE),                                            // CV swap buffer
        CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE),    // swiglu output
        CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE),     // group list
        CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE),        // expand index
        CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t),
               GM_ALIGN_SIZE),                                            // EP send counts
        CeilUp(maxTokenNum * hiddenLen * sizeof(int8_t), GM_ALIGN_SIZE),  // x1 tokens
        CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE),               // x1 scales
        CeilUp(maxTokenNum * hiddenLen * TOKEN_DTYPE_BYTE_SIZE,
               GM_ALIGN_SIZE),                                            // gmm2 dequant output
        CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE),                   // reserved tail
    };

    size_t usrSize = 0;
    for (size_t oneSize : bufferSizes) {
        usrSize += oneSize;
    }

    workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Core tiling routine: validates input shapes, gathers op attributes and
// platform info into the tiling data, sizes the workspace, configures the
// HCCL communication tilings, and finally selects the tiling key / block dim.
// Returns GRAPH_SUCCESS on success, GRAPH_FAILED on any validation failure.
static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context)
{
    const char *nodeName = context->GetNodeName();
    DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData<DispatchGmmCombineDecodeTilingData>();
    OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
    // Filled from the op's group attribute by GetAttrAndSetTilingData below.
    std::string groupEp = "";

    // x must be 2-D: [batchSize, hiddenSize].
    const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX);
    OPS_ERR_IF(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "x shape dims must be 2, but current dim num is %lu.",
                         xStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t batchSize = xStorageShape->GetStorageShape().GetDim(0);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs = batchSize;
    const int64_t hiddenSize = xStorageShape->GetStorageShape().GetDim(1);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize;

    // expertIds must be 2-D; dim 1 is the per-token top-k expert count.
    const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX);
    OPS_ERR_IF(expertIdsStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIds shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.",
                         expertIdsStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;

    // Pull op attributes (ranks, expert counts, quant mode, ...) into
    // tilingData and resolve the EP group name.
    OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
               OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);

    // gmm1HLen comes from origin-shape dim index TWO_DIMS (i.e. dim 2) of the
    // gmm1 weight; presumably the gmm1 output hidden length — TODO confirm.
    const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
               return ge::GRAPH_FAILED);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);

    // Record the physical core counts of the current platform.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;

    // Cross-field validation, then workspace sizing and HCCL configuration.
    OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
               OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
    SetHcommCfg(context, tilingData, groupEp);

    // Tiling key 0 = single MoE expert per rank; otherwise the deep-fuse path.
    if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
        context->SetTilingKey(0);
    } else {
        context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
    }
    // The kernel is launched with one block per AI-Core (cube core).
    context->SetBlockDim(aicNum);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Registered tiling entry point: a thin wrapper that forwards the context to
// the implementation and propagates its status unchanged.
static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context)
{
    return DispatchGmmCombineDecodeTilingFuncImpl(context);
}
|
||||
|
||||
struct DispatchGmmCombineDecodeCompileInfo {};
|
||||
ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// Register the tiling and tiling-parse entry points for the
// DispatchGmmCombineDecode operator with the op-tiling framework.
IMPL_OP_OPTILING(DispatchGmmCombineDecode)
    .Tiling(DispatchGmmCombineDecodeTilingFunc)
    .TilingParse<DispatchGmmCombineDecodeCompileInfo>(TilingParseForDispatchGmmCombineDecode);
|
||||
} // namespace optiling
|
||||
Reference in New Issue
Block a user