[Kernel] Add moe normal ops (#4810)
### What this PR does / why we need it?
1.Add the implementation of normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch,and DispatchLayout.
- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Used to calculate information related to the device
memory layout for the dispatch stage.
2.Provide PyTorch interfaces for normal operators—get_dispatch_layout,
dispatch_prefill, and combine_prefill—to be used for MoE communication
during the prefill stage in vLLM.
- get_dispatch_layout: Calculates information related to the device
memory layout for the dispatch operator, and is called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has already been validated using the local Qwen model.
Test cases will be added after support for multi-NPU use cases in the CI
pipeline is finalized.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
This commit is contained in:
49
csrc/moe_dispatch_normal/op_host/CMakeLists.txt
Normal file
49
csrc/moe_dispatch_normal/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME MoeDispatchNormal
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_moe_dispatch_normal.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
moe_dispatch_normal_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_dispatch_normal.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_moe_dispatch_normal.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern aclnnStatus aclnnInnerMoeDispatchNormalGetWorkspaceSize(
|
||||
const aclTensor *x,
|
||||
const aclTensor *topkIdx,
|
||||
const aclTensor *sendOffset,
|
||||
const aclTensor *sendTokenIdx,
|
||||
const aclTensor *recvOffset,
|
||||
const aclTensor *recvCount,
|
||||
char *groupEp,
|
||||
int64_t epWorldSize,
|
||||
int64_t epRankId,
|
||||
char *groupTpOptional,
|
||||
int64_t tpWorldSize,
|
||||
int64_t tpRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t quantMode,
|
||||
int64_t globalBs,
|
||||
const aclTensor *recvX,
|
||||
const aclTensor *recvXScales,
|
||||
const aclTensor *assistInfoForCombine,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerMoeDispatchNormal(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x, const aclTensor *topkIdx,
|
||||
const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset, const aclTensor *recvCount,
|
||||
char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, int64_t tpWorldSize, int64_t tpRankId,
|
||||
int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, const aclTensor *recvX,
|
||||
const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerMoeDispatchNormalGetWorkspaceSize(x,
|
||||
topkIdx,
|
||||
sendOffset,
|
||||
sendTokenIdx,
|
||||
recvOffset,
|
||||
recvCount,
|
||||
groupEp,
|
||||
epWorldSize,
|
||||
epRankId,
|
||||
groupTpOptional,
|
||||
tpWorldSize,
|
||||
tpRankId,
|
||||
moeExpertNum,
|
||||
quantMode,
|
||||
globalBs,
|
||||
recvX,
|
||||
recvXScales,
|
||||
assistInfoForCombine,
|
||||
workspaceSize,
|
||||
executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnMoeDispatchNormal(
|
||||
void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerMoeDispatchNormal(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
24
csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h
Normal file
24
csrc/moe_dispatch_normal/op_host/aclnn_moe_dispatch_normal.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#ifndef ACLNN_MOE_DISPATCH_NORMAL_H_
|
||||
#define ACLNN_MOE_DISPATCH_NORMAL_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormalGetWorkspaceSize(const aclTensor *x,
|
||||
const aclTensor *topkIdx, const aclTensor *sendOffset, const aclTensor *sendTokenIdx, const aclTensor *recvOffset,
|
||||
const aclTensor *recvCount, char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional,
|
||||
int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t quantMode, int64_t globalBs,
|
||||
const aclTensor *recvX, const aclTensor *recvXScales, const aclTensor *assistInfoForCombine,
|
||||
uint64_t *workspaceSize, aclOpExecutor **executor);
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMoeDispatchNormal(
|
||||
void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
92
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp
Normal file
92
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal.cpp
Normal file
@@ -0,0 +1,92 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class MoeDispatchNormal : public OpDef {
|
||||
public:
|
||||
explicit MoeDispatchNormal(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("topk_idx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Input("send_offset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("send_tokenIdx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("recv_offset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("recv_count")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
|
||||
this->Output("recv_x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_INT8, ge::DT_FLOAT16, ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Output("x_scales")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Output("assist_info_for_combine")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("group_ep").AttrType(REQUIRED).String();
|
||||
this->Attr("ep_world_size").AttrType(REQUIRED).Int();
|
||||
this->Attr("ep_rank_id").AttrType(REQUIRED).Int();
|
||||
this->Attr("group_tp").AttrType(OPTIONAL).String("");
|
||||
this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("moe_expert_num").AttrType(REQUIRED).Int();
|
||||
this->Attr("quant_mode").AttrType(OPTIONAL).Int(0);
|
||||
this->Attr("global_bs").AttrType(OPTIONAL).Int(0);
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_true")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->MC2().HcclGroup({"group_ep", "group_tp"});
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(MoeDispatchNormal);
|
||||
|
||||
} // namespace ops
|
||||
635
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp
Normal file
635
csrc/moe_dispatch_normal/op_host/moe_dispatch_normal_tiling.cpp
Normal file
@@ -0,0 +1,635 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/moe_dispatch_normal_tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace ge;
|
||||
namespace {
|
||||
class Mc2TilingUtils {
|
||||
public:
|
||||
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
|
||||
static uint64_t GetMaxWindowSize()
|
||||
{
|
||||
uint16_t defaultWindowSize = 200;
|
||||
if (getenv(HCCL_BUFFSIZE) == nullptr) {
|
||||
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
|
||||
} else {
|
||||
try {
|
||||
std::string envStr(getenv(HCCL_BUFFSIZE));
|
||||
defaultWindowSize = std::stoi(envStr);
|
||||
} catch (const std::invalid_argument &ia) {
|
||||
OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
|
||||
} catch (const std::out_of_range &oor) {
|
||||
OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
|
||||
}
|
||||
}
|
||||
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
|
||||
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
|
||||
return maxWindowSize;
|
||||
}
|
||||
};
|
||||
constexpr uint32_t X_INDEX = 0U;
|
||||
constexpr uint32_t EXPERT_IDS_INDEX = 1U;
|
||||
constexpr uint32_t SEND_OFFSET_INDEX = 2U;
|
||||
constexpr uint32_t SEND_TOKENIDX_INDEX = 3U;
|
||||
constexpr uint32_t RECV_OFFSET_INDEX = 4U;
|
||||
constexpr uint32_t RECV_COUNT_INDEX = 5U;
|
||||
|
||||
constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0U;
|
||||
constexpr uint32_t OUTPUT_DYNAMIC_SCALES_INDEX = 1U;
|
||||
constexpr uint32_t OUTPUT_ASSIST_INFO_INDEX = 2U;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
|
||||
constexpr uint32_t ATTR_GROUP_TP_INDEX = 3;
|
||||
constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 4;
|
||||
constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 5;
|
||||
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 6;
|
||||
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 7;
|
||||
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 8;
|
||||
|
||||
constexpr uint32_t TWO_DIMS = 2;
|
||||
constexpr uint32_t ONE_DIM = 1;
|
||||
constexpr uint32_t DYNAMIC_SCALE_DIM_NUM = 1;
|
||||
constexpr uint64_t INIT_TILINGKEY = 10000;
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
|
||||
constexpr uint32_t NO_SCALES = 0;
|
||||
constexpr uint32_t DYNAMIC_SCALES = 2;
|
||||
constexpr uint32_t OP_TYPE_ALL_GATHER = 6;
|
||||
|
||||
constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
|
||||
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
|
||||
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
|
||||
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
|
||||
constexpr int64_t BS_UPPER_BOUND = 8000; // Maximum bs
|
||||
|
||||
constexpr uint32_t TILINGKEY_TP_WORLD_SIZE = 100;
|
||||
constexpr uint32_t TP_WORLD_SIZE_TWO = 2;
|
||||
constexpr int64_t MOE_EXPERT_MAX_NUM = 512;
|
||||
constexpr int64_t K_MAX = 16;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t WORKSPACE_ELEMENT_OFFSET = 512;
|
||||
constexpr int64_t H_MIN = 1024;
|
||||
constexpr int64_t H_MAX = 7168;
|
||||
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
|
||||
constexpr uint64_t TRIPLE = 3;
|
||||
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
|
||||
constexpr uint64_t SCALE_EXPAND_IDX_BUFFER = 44UL; // scale32B + 3*4expandIdx
|
||||
constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
|
||||
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
|
||||
constexpr uint64_t UB_ALIGN = 32UL;
|
||||
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static void PrintTilingDataInfo(const char *nodeName, MoeDispatchNormalTilingData &tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "epWorldSize is %u.", tilingData.moeDispatchNormalInfo.epWorldSize);
|
||||
OPS_LOG_D(nodeName, "tpWorldSize is %u.", tilingData.moeDispatchNormalInfo.tpWorldSize);
|
||||
OPS_LOG_D(nodeName, "epRankId is %u.", tilingData.moeDispatchNormalInfo.epRankId);
|
||||
OPS_LOG_D(nodeName, "tpRankId is %u.", tilingData.moeDispatchNormalInfo.tpRankId);
|
||||
OPS_LOG_D(nodeName, "moeExpertNum is %u.", tilingData.moeDispatchNormalInfo.moeExpertNum);
|
||||
OPS_LOG_D(nodeName, "quantMode is %u.", tilingData.moeDispatchNormalInfo.quantMode);
|
||||
OPS_LOG_D(nodeName, "globalBs is %u.", tilingData.moeDispatchNormalInfo.globalBs);
|
||||
OPS_LOG_D(nodeName, "bs is %u.", tilingData.moeDispatchNormalInfo.bs);
|
||||
OPS_LOG_D(nodeName, "k is %u.", tilingData.moeDispatchNormalInfo.k);
|
||||
OPS_LOG_D(nodeName, "h is %u.", tilingData.moeDispatchNormalInfo.h);
|
||||
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.moeDispatchNormalInfo.aivNum);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.moeDispatchNormalInfo.totalUbSize);
|
||||
OPS_LOG_D(nodeName, "totalWinSize is %lu.", tilingData.moeDispatchNormalInfo.totalWinSize);
|
||||
}
|
||||
|
||||
static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
OPS_CHECK(xStorageShape == nullptr, OPS_LOG_E(nodeName, "xShape is null."), return false);
|
||||
OPS_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"xShape dims must be 2, but current dim num is %lu.",
|
||||
xStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_LOG_D(nodeName, "x dim0 = %ld", xDim0);
|
||||
OPS_LOG_D(nodeName, "x dim1 = %ld", xDim1);
|
||||
|
||||
const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIdShape is null."), return false);
|
||||
OPS_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expertIdShape dims must be 2, but current dim num is %lu.",
|
||||
expertIdStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXStorageShape == nullptr, OPS_LOG_E(nodeName, "expandXShape is null."), return false);
|
||||
OPS_CHECK(expandXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expandXShape dims must be 2, but current dim num is %lu.",
|
||||
expandXStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "expandX dim0 = %ld", expandXStorageShape->GetStorageShape().GetDim(0));
|
||||
OPS_LOG_D(nodeName, "expandX dim1 = %ld", expandXStorageShape->GetStorageShape().GetDim(1));
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(
|
||||
dynamicScalesStorageShape == nullptr, OPS_LOG_E(nodeName, "dynamicScalesShape is null."), return false);
|
||||
OPS_CHECK(dynamicScalesStorageShape->GetStorageShape().GetDimNum() != DYNAMIC_SCALE_DIM_NUM,
|
||||
OPS_LOG_E(nodeName,
|
||||
"dynamicScalesShape dims must be %u, but current dim num is %lu.",
|
||||
DYNAMIC_SCALE_DIM_NUM,
|
||||
dynamicScalesStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "dynamicScales dim0 = %ld", dynamicScalesStorageShape->GetStorageShape().GetDim(0));
|
||||
}
|
||||
|
||||
const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoStorageShape == nullptr, OPS_LOG_E(nodeName, "assistInfoShape is null."), return false);
|
||||
OPS_CHECK(assistInfoStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
|
||||
OPS_LOG_E(nodeName,
|
||||
"assistInfoShape dims must be 1, but current dim num is %lu.",
|
||||
assistInfoStorageShape->GetStorageShape().GetDimNum()),
|
||||
return false);
|
||||
OPS_LOG_D(nodeName, "assistInfoForCombine dim0 = %ld", assistInfoStorageShape->GetStorageShape().GetDim(0));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK((xDesc->GetDataType() != ge::DT_BF16) && (xDesc->GetDataType() != ge::DT_FLOAT16),
|
||||
OPS_LOG_E(nodeName, "x dataType is invalid, dataType should be bf16 or float16, but is ."),
|
||||
return false);
|
||||
|
||||
auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false);
|
||||
OPS_CHECK(expertIdDesc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(nodeName, "expertId dataType is invalid, dataType should be int32, but is ."),
|
||||
return false);
|
||||
|
||||
auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false);
|
||||
if (quantMode != NO_SCALES) {
|
||||
OPS_CHECK(expandXDesc->GetDataType() != ge::DT_INT8,
|
||||
OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be int8, but is."),
|
||||
return false);
|
||||
} else {
|
||||
OPS_CHECK(expandXDesc->GetDataType() != xDesc->GetDataType(),
|
||||
OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be equal to x dataType , but is."),
|
||||
return false);
|
||||
}
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false);
|
||||
OPS_CHECK(dynamicScalesDesc->GetDataType() != ge::DT_FLOAT,
|
||||
OPS_LOG_E(nodeName, "dynamicScales dataType is invalid, dataType should be float, but is ."),
|
||||
return false);
|
||||
}
|
||||
|
||||
auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false);
|
||||
OPS_CHECK(assistInfoDesc->GetDataType() != ge::DT_INT32,
|
||||
OPS_LOG_E(nodeName, "assistInfoForCombine dataType is invalid, dataType should be int32, but is ."),
|
||||
return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||
OPS_CHECK(xDesc == nullptr, OPS_LOG_E(nodeName, "xDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "x format is invalid."),
|
||||
return false);
|
||||
|
||||
auto expertIdDesc = context->GetInputDesc(EXPERT_IDS_INDEX);
|
||||
OPS_CHECK(expertIdDesc == nullptr, OPS_LOG_E(nodeName, "expertIdDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(expertIdDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "expertId format is invalid."),
|
||||
return false);
|
||||
|
||||
auto expandXDesc = context->GetOutputDesc(OUTPUT_EXPAND_X_INDEX);
|
||||
OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandXDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(expandXDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "expandX format is invalid."),
|
||||
return false);
|
||||
|
||||
if (quantMode == DYNAMIC_SCALES) {
|
||||
auto dynamicScalesDesc = context->GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
OPS_CHECK(dynamicScalesDesc == nullptr, OPS_LOG_E(nodeName, "dynamicScalesDesc is null."), return false);
|
||||
OPS_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(dynamicScalesDesc->GetStorageFormat())) ==
|
||||
ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "dynamicScales format is invalid."),
|
||||
return false);
|
||||
}
|
||||
|
||||
auto assistInfoDesc = context->GetOutputDesc(OUTPUT_ASSIST_INFO_INDEX);
|
||||
OPS_CHECK(assistInfoDesc == nullptr, OPS_LOG_E(nodeName, "assistInfoDesc is null."), return false);
|
||||
OPS_CHECK(
|
||||
static_cast<ge::Format>(ge::GetPrimaryFormat(assistInfoDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
|
||||
OPS_LOG_E(nodeName, "assistInfoForCombine format is invalid."),
|
||||
return false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
|
||||
MoeDispatchNormalTilingData &tilingData, std::string &groupEp, std::string &groupTp)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
|
||||
auto groupTpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_TP_INDEX));
|
||||
auto epWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_WORLD_SIZE_INDEX);
|
||||
auto tpWorldSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_WORLD_SIZE_INDEX);
|
||||
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
|
||||
auto tpRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_TP_RANK_ID_INDEX);
|
||||
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
|
||||
auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
|
||||
|
||||
// Check for null
|
||||
OPS_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "groupEpPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "epWorldSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(nodeName, "tpWorldSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(nodeName, "tpRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(quantModePtr == nullptr, OPS_LOG_E(nodeName, "quantModePtr is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if it meets uint32_t and other constraints
|
||||
int64_t moeExpertNum = *moeExpertNumPtr;
|
||||
int64_t epWorldSize = *epWorldSizePtr;
|
||||
OPS_CHECK((epWorldSize < MIN_EP_WORLD_SIZE) || (epWorldSize > MAX_EP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"epWorldSize is invalid, only support [%ld, %ld], but got epWorldSize=%ld.",
|
||||
MIN_EP_WORLD_SIZE,
|
||||
MAX_EP_WORLD_SIZE,
|
||||
epWorldSize),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.",
|
||||
MAX_TP_WORLD_SIZE,
|
||||
*tpWorldSizePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= epWorldSize),
|
||||
OPS_LOG_E(
|
||||
nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.", epWorldSize, *epRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
if (*tpWorldSizePtr > 1) {
|
||||
OPS_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.",
|
||||
*tpWorldSizePtr,
|
||||
*tpRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "groupTpPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
groupTp = std::string(groupTpPtr);
|
||||
} else {
|
||||
OPS_CHECK(*tpRankIdPtr != 0,
|
||||
OPS_LOG_E(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
OPS_CHECK((moeExpertNum <= 0) || (moeExpertNum > MOE_EXPERT_MAX_NUM),
|
||||
OPS_LOG_E(nodeName,
|
||||
"moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.",
|
||||
MOE_EXPERT_MAX_NUM,
|
||||
moeExpertNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(
|
||||
(*quantModePtr < static_cast<int64_t>(NO_SCALES)) || (*quantModePtr > static_cast<int64_t>(DYNAMIC_SCALES)),
|
||||
OPS_LOG_E(nodeName,
|
||||
"quantMode is invalid, only support [0, %u], but got quantMode=%ld.",
|
||||
DYNAMIC_SCALES,
|
||||
*quantModePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
int64_t moePerRankNum = moeExpertNum / epWorldSize;
|
||||
int64_t curDispatchStatusNum = moePerRankNum * epWorldSize;
|
||||
OPS_CHECK((curDispatchStatusNum > DISPATCH_STATUS_MAX_SUPPORT_NUM),
|
||||
OPS_LOG_E(nodeName,
|
||||
"The moe experts num must meet the conditions,"
|
||||
" (moeExpertNum / epWorldSize * epWorldSize <= 1280, but cur is %ld.",
|
||||
curDispatchStatusNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
groupEp = std::string(groupEpPtr);
|
||||
tilingData.moeDispatchNormalInfo.epWorldSize = static_cast<uint32_t>(epWorldSize);
|
||||
tilingData.moeDispatchNormalInfo.tpWorldSize = static_cast<uint32_t>(*tpWorldSizePtr);
|
||||
tilingData.moeDispatchNormalInfo.epRankId = static_cast<uint32_t>(*epRankIdPtr);
|
||||
tilingData.moeDispatchNormalInfo.tpRankId = static_cast<uint32_t>(*tpRankIdPtr);
|
||||
tilingData.moeDispatchNormalInfo.moeExpertNum = static_cast<uint32_t>(moeExpertNum);
|
||||
tilingData.moeDispatchNormalInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckAttrs(
|
||||
gert::TilingContext *context, const char *nodeName, MoeDispatchNormalTilingData &tilingData, uint32_t &localMoeExpertNum)
|
||||
{
|
||||
uint32_t epWorldSize = tilingData.moeDispatchNormalInfo.epWorldSize;
|
||||
uint32_t tpWorldSize = tilingData.moeDispatchNormalInfo.tpWorldSize;
|
||||
uint32_t moeExpertNum = tilingData.moeDispatchNormalInfo.moeExpertNum;
|
||||
|
||||
// Validate if moe expert number can be evenly distributed across multiple machines
|
||||
localMoeExpertNum = moeExpertNum / epWorldSize;
|
||||
OPS_CHECK(moeExpertNum % epWorldSize != 0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"moeExpertNum should be divisible by epWorldSize, "
|
||||
"but moeExpertNum=%u, epWorldSize=%u.",
|
||||
moeExpertNum,
|
||||
epWorldSize),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(localMoeExpertNum <= 0,
|
||||
OPS_LOG_E(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %d", localMoeExpertNum),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate input x dimension 0 and set bs
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
OPS_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0),
|
||||
OPS_LOG_E(
|
||||
nodeName, "xDim0(BS) is invalid. Should be between [1, %ld], but got xDim0=%ld.", BS_UPPER_BOUND, xDim0),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData.moeDispatchNormalInfo.bs = static_cast<uint32_t>(xDim0);
|
||||
|
||||
// Validate globalBS
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
|
||||
OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(nodeName, "globalBsPtr is nullptr."), return ge::GRAPH_FAILED);
|
||||
OPS_LOG_D(nodeName, "MoeDispatchNormal *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n", *globalBsPtr, xDim0, epWorldSize);
|
||||
OPS_CHECK(*globalBsPtr <= 0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"globalBS is invalid, should be positive, but got globalBS=%ld.",
|
||||
*globalBsPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
tilingData.moeDispatchNormalInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
|
||||
MoeDispatchNormalTilingData &tilingData, const uint32_t quantMode, const int64_t localMoeExpertNum)
|
||||
{
|
||||
uint32_t A = 0U;
|
||||
uint32_t globalBs = tilingData.moeDispatchNormalInfo.globalBs;
|
||||
|
||||
// Validate input x dimension 1 and set h, bs already validated
|
||||
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
|
||||
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK((xDim1 < H_MIN) || (xDim1 > H_MAX),
|
||||
OPS_LOG_E(nodeName, "xShape dims1(H) should be in [%ld, %ld], but got %ld.", H_MIN, H_MAX, xDim1),
|
||||
return ge::GRAPH_FAILED); // 32-byte aligned
|
||||
tilingData.moeDispatchNormalInfo.h = static_cast<uint32_t>(xDim1);
|
||||
|
||||
// Validate expert_id dimensions and set k
|
||||
int64_t moeExpertNum = static_cast<int64_t>(tilingData.moeDispatchNormalInfo.moeExpertNum);
|
||||
const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX);
|
||||
const int64_t expertIdsDim0 = expertIdStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t expertIdsDim1 = expertIdStorageShape->GetStorageShape().GetDim(1);
|
||||
OPS_CHECK(xDim0 != expertIdsDim0,
|
||||
OPS_LOG_E(nodeName,
|
||||
"xShape's dim0 not equal to expertIdShape's dim0, "
|
||||
"xShape's dim0 is %ld, expertIdShape's dim0 is %ld.",
|
||||
xDim0,
|
||||
expertIdsDim0),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((expertIdsDim1 <= 0) || (expertIdsDim1 > K_MAX) || (expertIdsDim1 > moeExpertNum),
|
||||
OPS_LOG_E(nodeName,
|
||||
"expertIdShape's dim1(k) should be in (0, min(%ld, moeExpertNum=%ld)], "
|
||||
"but got expertIdShape's dim1=%ld.",
|
||||
K_MAX,
|
||||
moeExpertNum,
|
||||
expertIdsDim1),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData.moeDispatchNormalInfo.k = static_cast<uint32_t>(expertIdsDim1);
|
||||
|
||||
A = globalBs;
|
||||
|
||||
// Validate expandX dimensions
|
||||
const gert::StorageShape *expandXStorageShape = context->GetOutputShape(OUTPUT_EXPAND_X_INDEX);
|
||||
const int64_t expandXDim0 = expandXStorageShape->GetStorageShape().GetDim(0);
|
||||
const int64_t expandXDim1 = expandXStorageShape->GetStorageShape().GetDim(1);
|
||||
|
||||
OPS_CHECK(xDim1 != expandXDim1,
|
||||
OPS_LOG_E(nodeName,
|
||||
"expandX's dim1 not equal to xShape's dim1, "
|
||||
"xShape's dim1 is %ld, expandX's dim1 is %ld.",
|
||||
xDim1,
|
||||
expandXDim1),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate dynamicScales dimensions
|
||||
if (quantMode != NO_SCALES) {
|
||||
const gert::StorageShape *dynamicScalesStorageShape = context->GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX);
|
||||
const int64_t dynamicScalesDim0 = dynamicScalesStorageShape->GetStorageShape().GetDim(0);
|
||||
}
|
||||
|
||||
// Validate assistInfo dimensions
|
||||
const gert::StorageShape *assistInfoStorageShape = context->GetOutputShape(OUTPUT_ASSIST_INFO_INDEX);
|
||||
const int64_t assistInfoDim0 = assistInfoStorageShape->GetStorageShape().GetDim(0);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckMoeDispatchNormal(
|
||||
gert::TilingContext *context, const char *nodeName, const uint32_t quantMode)
|
||||
{
|
||||
OPS_CHECK(!CheckTensorDim(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params shape is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(!CheckTensorFormat(context, nodeName, quantMode),
|
||||
OPS_LOG_E(nodeName, "params format is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void CalTilingKey(uint64_t &tilingKey, const uint32_t quantMode, const uint32_t tpWorldSize)
|
||||
{
|
||||
tilingKey += static_cast<uint64_t>(quantMode);
|
||||
if (tpWorldSize == TP_WORLD_SIZE_TWO) {
|
||||
tilingKey += static_cast<uint64_t>(TILINGKEY_TP_WORLD_SIZE);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
static void SetHcommCfg(const gert::TilingContext *context, MoeDispatchNormalTilingData *tiling, const std::string groupEp,
|
||||
const std::string groupTp)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "MoeDispatchNormal groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str());
|
||||
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
|
||||
uint32_t opType2 = OP_TYPE_ALL_GATHER;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
std::string algConfigAllGatherStr = "AllGather=level0:ring";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
|
||||
|
||||
mc2CcTilingConfig.SetGroupName(groupTp);
|
||||
mc2CcTilingConfig.SetOpType(opType2);
|
||||
mc2CcTilingConfig.SetAlgConfig(algConfigAllGatherStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling2);
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
workSpaces[0] = static_cast<uint64_t>(SYSTEM_NEED_WORKSPACE + WORKSPACE_ELEMENT_OFFSET * aivNum * aivNum);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeDispatchNormalA3TilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
MoeDispatchNormalTilingData *tilingData = context->GetTilingData<MoeDispatchNormalTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
std::string groupEp = "";
|
||||
std::string groupTp = "";
|
||||
uint32_t quantMode = NO_SCALES;
|
||||
uint32_t localMoeExpertNum = 1;
|
||||
OPS_LOG_I(nodeName, "Enter MoeDispatchNormal tiling check func.");
|
||||
|
||||
// Get input parameter attributes
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp, groupTp) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
quantMode = tilingData->moeDispatchNormalInfo.quantMode;
|
||||
|
||||
// Check input/output dim, format, dataType
|
||||
OPS_CHECK(TilingCheckMoeDispatchNormal(context, nodeName, quantMode) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check param failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Check if attribute values are valid
|
||||
OPS_CHECK(CheckAttrs(context, nodeName, *tilingData, localMoeExpertNum) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Check attr failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
uint32_t epRankId = tilingData->moeDispatchNormalInfo.epRankId;
|
||||
|
||||
// Check shape dimensions and assign h, k
|
||||
OPS_CHECK(
|
||||
CheckTensorShape(context, nodeName, *tilingData, quantMode, static_cast<int64_t>(localMoeExpertNum)) !=
|
||||
ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Check tensor shape failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
// Validate win area size
|
||||
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
|
||||
uint64_t h = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.h);
|
||||
uint64_t k = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.k);
|
||||
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.epWorldSize);
|
||||
uint64_t maxBs = static_cast<uint64_t>(tilingData->moeDispatchNormalInfo.globalBs) / epWorldSize;
|
||||
|
||||
// Dispatch data area: token start aligned to 512, valid token length h_align_32b + scale(32b) + triplet(3*4b)
|
||||
uint64_t tokenActualLen =
|
||||
((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_EXPAND_IDX_BUFFER;
|
||||
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
// Not considering dual stream size
|
||||
uint64_t actualSize = maxBs * k * tokenNeedSizeDispatch * DOUBLE_DATA_BUFFER;
|
||||
OPS_CHECK((actualSize > maxWindowSize),
|
||||
OPS_LOG_E(nodeName,
|
||||
"HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu,"
|
||||
" localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu,"
|
||||
" k = %lu, NEEDED_HCCL_BUFFSIZE(maxBs * k * tokenNeedSizeDispatch) = %luMB,"
|
||||
" HCCL_BUFFSIZE=%luMB.",
|
||||
maxBs,
|
||||
h,
|
||||
epWorldSize,
|
||||
localMoeExpertNum,
|
||||
tokenNeedSizeDispatch,
|
||||
k,
|
||||
actualSize / MB_SIZE + 1UL,
|
||||
maxWindowSize / MB_SIZE),
|
||||
return ge::GRAPH_FAILED);
|
||||
tilingData->moeDispatchNormalInfo.totalWinSize = maxWindowSize;
|
||||
OPS_LOG_D(nodeName, "windowSize = %lu", maxWindowSize);
|
||||
|
||||
OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, groupEp, groupTp);
|
||||
uint32_t tpWorldSize = tilingData->moeDispatchNormalInfo.tpWorldSize;
|
||||
uint64_t tilingKey = INIT_TILINGKEY;
|
||||
CalTilingKey(tilingKey, quantMode, tpWorldSize);
|
||||
OPS_LOG_D(nodeName, "tilingKey is %lu", tilingKey);
|
||||
context->SetTilingKey(tilingKey);
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
|
||||
context->SetBlockDim(blockDim);
|
||||
context->SetScheduleMode(1); // Set to batch mode, all cores start simultaneously
|
||||
tilingData->moeDispatchNormalInfo.totalUbSize = ubSize;
|
||||
tilingData->moeDispatchNormalInfo.aivNum = aivNum;
|
||||
OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MoeDispatchNormalTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
ge::graphStatus ret = MoeDispatchNormalA3TilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct MoeDispatchNormalCompileInfo {};
|
||||
ge::graphStatus TilingParseForMoeDispatchNormal(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(MoeDispatchNormal)
|
||||
.Tiling(MoeDispatchNormalTilingFunc)
|
||||
.TilingParse<MoeDispatchNormalCompileInfo>(TilingParseForMoeDispatchNormal);
|
||||
} // namespace optiling
|
||||
56
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp
Normal file
56
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "moe_dispatch_normal_tiling.h"
|
||||
#include "moe_dispatch_normal.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace MoeDispatchNormalImpl;
|
||||
|
||||
#define TILINGKEY_NO_QUANT 10000
|
||||
#define TILINGKEY_QUANT 10002
|
||||
|
||||
extern "C" __global__ __aicore__ void moe_dispatch_normal(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset,
|
||||
GM_ADDR send_token_idx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut,
|
||||
GM_ADDR assist_info_for_combine, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(MoeDispatchNormalTilingData);
|
||||
TPipe pipe;
|
||||
#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
|
||||
if (TILING_KEY_IS(TILINGKEY_NO_QUANT)) {
|
||||
GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM);
|
||||
MoeDispatchNormal<DTYPE_X, DTYPE_RECV_X, false, false, false> op;
|
||||
op.Init(x,
|
||||
expertIds,
|
||||
send_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
expandXOut,
|
||||
dynamicScalesOut,
|
||||
assist_info_for_combine,
|
||||
workspaceGM,
|
||||
&pipe,
|
||||
&tilingData);
|
||||
op.Process();
|
||||
return;
|
||||
}
|
||||
#elif (ORIG_DTYPE_RECV_X == DT_INT8)
|
||||
if (TILING_KEY_IS(TILINGKEY_QUANT)) {
|
||||
GET_TILING_DATA_WITH_STRUCT(MoeDispatchNormalTilingData, tilingData, tilingGM);
|
||||
MoeDispatchNormal<DTYPE_X, DTYPE_RECV_X, true, false, false> op;
|
||||
op.Init(x,
|
||||
expertIds,
|
||||
send_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
expandXOut,
|
||||
dynamicScalesOut,
|
||||
assist_info_for_combine,
|
||||
workspaceGM,
|
||||
&pipe,
|
||||
&tilingData);
|
||||
op.Process();
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
540
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h
Normal file
540
csrc/moe_dispatch_normal/op_kernel/moe_dispatch_normal.h
Normal file
@@ -0,0 +1,540 @@
|
||||
#ifndef MOE_DISPATCH_NORMAL_H
|
||||
#define MOE_DISPATCH_NORMAL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_dispatch_normal_tiling.h"
|
||||
|
||||
namespace MoeDispatchNormalImpl {
|
||||
constexpr uint8_t BUFFER_NUM = 2;
|
||||
constexpr uint32_t STATE_OFFSET = 32U;
|
||||
constexpr uint32_t UB_ALIGN = 32U;
|
||||
constexpr uint8_t COMM_NUM = 2;
|
||||
constexpr uint8_t COMM_EP_IDX = 0;
|
||||
constexpr uint8_t COMM_TP_IDX = 1;
|
||||
|
||||
constexpr uint64_t WIN_STATE_OFFSET = 500UL * 1024UL;
|
||||
constexpr uint64_t STATE_WIN_OFFSET = 950UL * 1024UL;
|
||||
constexpr uint64_t WIN_ADDR_ALIGN = 512UL;
|
||||
constexpr uint32_t EXPAND_IDX_INFO = 3U;
|
||||
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL;
|
||||
|
||||
template <AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
#define CamTypeClass \
|
||||
typename XType, typename ExpandXOutType, bool DynamicQuant, bool IsSmoothScaleExist, bool IsShareExpertRank
|
||||
|
||||
#define CamTypeFunc XType, ExpandXOutType, DynamicQuant, IsSmoothScaleExist, IsShareExpertRank
|
||||
|
||||
using namespace AscendC;
|
||||
template <CamTypeClass>
|
||||
class MoeDispatchNormal {
|
||||
public:
|
||||
__aicore__ inline MoeDispatchNormal(){};
|
||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx,
|
||||
GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
|
||||
GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData);
|
||||
__aicore__ inline void Process();
|
||||
|
||||
private:
|
||||
__aicore__ inline void InputToShare();
|
||||
__aicore__ inline void SetStatus();
|
||||
__aicore__ inline void WaitStatus();
|
||||
__aicore__ inline void ShareToOutput();
|
||||
__aicore__ inline void UpdateOutput();
|
||||
__aicore__ inline void FillTriple(LocalTensor<ExpandXOutType> &xOutTensor, uint32_t tokenIndex, uint32_t k);
|
||||
__aicore__ inline void QuantInit();
|
||||
__aicore__ inline void ReduceMaxInplace(const LocalTensor<float> &srcLocal, uint32_t count);
|
||||
__aicore__ inline void QuantProcess();
|
||||
__aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId)
|
||||
{
|
||||
uint32_t curRankId = ((ctxIdx == COMM_EP_IDX) ? epRankId : tpRankId);
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset + COMBINE_STATE_WIN_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) +
|
||||
winDataSizeOffset + COMBINE_STATE_WIN_OFFSET;
|
||||
}
|
||||
|
||||
__aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId)
|
||||
{
|
||||
uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId : tpRankId;
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState * WIN_STATE_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))
|
||||
->windowsExp) +
|
||||
dataState * WIN_STATE_OFFSET;
|
||||
}
|
||||
|
||||
TPipe *tpipe_{nullptr};
|
||||
GlobalTensor<XType> xGT;
|
||||
GlobalTensor<int32_t> expertIdsGT;
|
||||
GlobalTensor<int32_t> sendOffsetGT;
|
||||
GlobalTensor<int32_t> sendTokenIdxGT;
|
||||
GlobalTensor<int32_t> recvOffsetGT;
|
||||
GlobalTensor<int32_t> recvCountGT;
|
||||
GlobalTensor<float> dynamicScalesOutGT;
|
||||
GlobalTensor<int32_t> expandIdxOutGT;
|
||||
GlobalTensor<ExpandXOutType> dstGT;
|
||||
GlobalTensor<int32_t> dstStatusGT;
|
||||
|
||||
LocalTensor<XType> xInTensor;
|
||||
LocalTensor<ExpandXOutType> xOutTensor;
|
||||
LocalTensor<ExpandXOutType> xTmpTensor;
|
||||
LocalTensor<int32_t> expertIdsTensor;
|
||||
LocalTensor<int32_t> sendOffsetTensor;
|
||||
LocalTensor<int32_t> sendTokenIdxTensor;
|
||||
LocalTensor<int32_t> recvOffsetTensor;
|
||||
LocalTensor<int32_t> recvCountTensor;
|
||||
LocalTensor<int32_t> statusTensor;
|
||||
|
||||
TBuf<> expertIdsBuf;
|
||||
TBuf<> sendOffsetBuf;
|
||||
TBuf<> sendTokenIdxBuf;
|
||||
TBuf<> recvOffsetBuf;
|
||||
TBuf<> recvCountBuf;
|
||||
TBuf<> statusBuf;
|
||||
TBuf<> waitStatusBuf;
|
||||
TBuf<> gatherMaskOutBuf;
|
||||
TBuf<> scalarBuf;
|
||||
TBuf<> tokenCastFloatBuf;
|
||||
TBuf<> tokenAbsFloatBuf;
|
||||
|
||||
GM_ADDR expandXOutGM;
|
||||
GM_ADDR shareGM;
|
||||
|
||||
uint32_t batchSize{0};
|
||||
uint32_t globalBatchSize{0};
|
||||
uint32_t h{0};
|
||||
uint32_t topK{0};
|
||||
uint32_t blockNum{0};
|
||||
uint32_t blockIdx{0};
|
||||
uint32_t epRankSize{0};
|
||||
uint32_t epRankId{0};
|
||||
uint32_t tpRankSize{0};
|
||||
uint32_t tpRankId{0};
|
||||
uint32_t moeExpertNum{0};
|
||||
uint32_t moeExpertNumPerRank{0};
|
||||
|
||||
uint32_t hUBAlignSize{0};
|
||||
uint32_t hOutGMAlignSize{0};
|
||||
uint32_t hOutUBAlignSize{0};
|
||||
uint32_t hGMAlignCnt{0};
|
||||
uint32_t expandIdxStartIdx{0};
|
||||
uint32_t expertIdsCnt{0};
|
||||
uint32_t stateOffset{0};
|
||||
uint32_t dataState{0};
|
||||
uint32_t winDataSizeOffset{0};
|
||||
|
||||
uint32_t startStatusId;
|
||||
uint32_t endStatusId;
|
||||
uint32_t statusNumPerCore;
|
||||
uint32_t remainStatus;
|
||||
|
||||
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue;
|
||||
TQue<QuePosition::VECIN, 1> xInQueue;
|
||||
TQue<QuePosition::VECOUT, 1> xOutQueue;
|
||||
|
||||
__gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr};
|
||||
|
||||
DataCopyExtParams hCommuCopyOutParams;
|
||||
};
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset,
|
||||
GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut,
|
||||
GM_ADDR expandIdxOut, GM_ADDR workspaceGM, TPipe *pipe, const MoeDispatchNormalTilingData *tilingData)
|
||||
{
|
||||
tpipe_ = pipe;
|
||||
blockIdx = GetBlockIdx();
|
||||
|
||||
winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>();
|
||||
|
||||
GlobalTensor<int32_t> selfDataStatusTensor;
|
||||
GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp);
|
||||
selfDataStatusTensor.SetGlobalBuffer(
|
||||
(__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET + blockIdx * WIN_ADDR_ALIGN));
|
||||
|
||||
batchSize = tilingData->moeDispatchNormalInfo.bs;
|
||||
globalBatchSize = tilingData->moeDispatchNormalInfo.globalBs;
|
||||
h = tilingData->moeDispatchNormalInfo.h;
|
||||
topK = tilingData->moeDispatchNormalInfo.k;
|
||||
blockNum = tilingData->moeDispatchNormalInfo.aivNum;
|
||||
epRankSize = tilingData->moeDispatchNormalInfo.epWorldSize;
|
||||
epRankId = tilingData->moeDispatchNormalInfo.epRankId;
|
||||
moeExpertNum = tilingData->moeDispatchNormalInfo.moeExpertNum;
|
||||
moeExpertNumPerRank = moeExpertNum / epRankSize;
|
||||
|
||||
xGT.SetGlobalBuffer((__gm__ XType *)x);
|
||||
expertIdsGT.SetGlobalBuffer((__gm__ int32_t *)expertIds);
|
||||
sendOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(send_offset));
|
||||
sendTokenIdxGT.SetGlobalBuffer((__gm__ int32_t *)(send_tokenIdx));
|
||||
recvOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(recv_offset));
|
||||
recvCountGT.SetGlobalBuffer((__gm__ int32_t *)(recv_count));
|
||||
dynamicScalesOutGT.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
|
||||
expandIdxOutGT.SetGlobalBuffer((__gm__ int32_t *)(expandIdxOut));
|
||||
|
||||
expandXOutGM = expandXOut;
|
||||
|
||||
hUBAlignSize = Ceil(h * sizeof(ExpandXOutType), UB_ALIGN) * UB_ALIGN;
|
||||
uint32_t hScaleSizeAlign = hUBAlignSize + UB_ALIGN;
|
||||
expandIdxStartIdx = hScaleSizeAlign / sizeof(int32_t);
|
||||
|
||||
uint32_t hScaleIdxSize = hScaleSizeAlign + EXPAND_IDX_INFO * sizeof(int32_t);
|
||||
hOutGMAlignSize = Ceil(hScaleIdxSize, WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
hGMAlignCnt = hOutGMAlignSize / sizeof(ExpandXOutType);
|
||||
|
||||
expertIdsCnt = batchSize * topK;
|
||||
statusNumPerCore = moeExpertNum / blockNum;
|
||||
remainStatus = moeExpertNum % blockNum;
|
||||
startStatusId = statusNumPerCore * blockIdx;
|
||||
if (blockIdx < remainStatus) {
|
||||
statusNumPerCore += 1;
|
||||
startStatusId += blockIdx;
|
||||
} else {
|
||||
startStatusId += remainStatus;
|
||||
}
|
||||
endStatusId = startStatusId + statusNumPerCore;
|
||||
stateOffset = STATE_OFFSET;
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfDataStatusTensor);
|
||||
dataState = selfDataStatusTensor(0);
|
||||
if (dataState == 0) {
|
||||
selfDataStatusTensor(0) = 1;
|
||||
} else {
|
||||
selfDataStatusTensor(0) = 0;
|
||||
}
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(selfDataStatusTensor);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
|
||||
uint64_t hSizeAlignCombine = Ceil(h * sizeof(XType), WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
|
||||
winDataSizeOffset = dataState * (tilingData->moeDispatchNormalInfo.totalWinSize / 2) +
|
||||
globalBatchSize / epRankSize * topK * hSizeAlignCombine;
|
||||
shareGM = GetWindAddrByRankId(COMM_EP_IDX, epRankId);
|
||||
|
||||
hOutUBAlignSize = Ceil(hScaleIdxSize, UB_ALIGN) * UB_ALIGN;
|
||||
if constexpr (DynamicQuant) {
|
||||
QuantInit();
|
||||
} else {
|
||||
tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 2 * 14K = 28K
|
||||
}
|
||||
|
||||
tpipe_->InitBuffer(sendOffsetBuf, moeExpertNum * sizeof(int32_t)); // 4 * moeNum
|
||||
sendOffsetTensor = sendOffsetBuf.Get<int32_t>();
|
||||
|
||||
hCommuCopyOutParams = {1U, static_cast<uint32_t>(hScaleIdxSize), 0U, 0U, 0U};
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::QuantInit()
|
||||
{
|
||||
uint32_t hAlignSize = Ceil(h * sizeof(XType), UB_ALIGN) * UB_ALIGN;
|
||||
tpipe_->InitBuffer(xInQueue, BUFFER_NUM, hAlignSize); // 14K * 2
|
||||
tpipe_->InitBuffer(xOutQueue, BUFFER_NUM, hOutUBAlignSize); // 7K * 2
|
||||
|
||||
tpipe_->InitBuffer(tokenCastFloatBuf, h * sizeof(float)); // 28K
|
||||
tpipe_->InitBuffer(tokenAbsFloatBuf, h * sizeof(float)); // 28K
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::ReduceMaxInplace(
|
||||
const LocalTensor<float> &srcLocal, uint32_t count)
|
||||
{
|
||||
uint64_t repsFp32 = count >> 6; // 6 is count / elemPerRefFp32
|
||||
uint64_t offsetsFp32 = repsFp32 << 6; // 6 is repsFp32 * elemPerRefFp32
|
||||
uint64_t remsFp32 = count & 0x3f; // 0x3f 63, count % elemPerRefFp32
|
||||
const uint64_t elemPerRefFp32 = 64UL; // 256 bit / sizeof(float)
|
||||
if (likely(repsFp32 > 1)) {
|
||||
// 8 is rep stride
|
||||
Max(srcLocal, srcLocal[elemPerRefFp32], srcLocal, elemPerRefFp32, repsFp32 - 1, {1, 1, 1, 0, 8, 0});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (unlikely(remsFp32 > 0) && unlikely(offsetsFp32 > 0)) {
|
||||
Max(srcLocal, srcLocal[offsetsFp32], srcLocal, remsFp32, 1, {1, 1, 1, 0, 8, 0});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
uint32_t mask = (repsFp32 > 0) ? elemPerRefFp32 : count;
|
||||
// 8 is rep stride
|
||||
WholeReduceMax(srcLocal, srcLocal, mask, 1, 8, 1, 8);
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::QuantProcess()
|
||||
{
|
||||
float dynamicScale = 0.0;
|
||||
LocalTensor<float> floatLocalTemp;
|
||||
floatLocalTemp = tokenCastFloatBuf.Get<float>();
|
||||
|
||||
Cast(floatLocalTemp, xInTensor, RoundMode::CAST_NONE, h);
|
||||
xInQueue.FreeTensor<XType>(xInTensor);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
if constexpr (DynamicQuant) {
|
||||
LocalTensor<float> floatLocalAbsTemp = tokenAbsFloatBuf.Get<float>();
|
||||
|
||||
Abs(floatLocalAbsTemp, floatLocalTemp, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceMaxInplace(floatLocalAbsTemp, h);
|
||||
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
dynamicScale = float(127.0) / (floatLocalAbsTemp.GetValue(0) + 1e-12f);
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Muls(floatLocalTemp, floatLocalTemp, dynamicScale, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
LocalTensor<half> halfLocalTemp = floatLocalTemp.ReinterpretCast<half>();
|
||||
LocalTensor<int32_t> int32LocalTemp = floatLocalTemp.ReinterpretCast<int32_t>();
|
||||
Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, h);
|
||||
PipeBarrier<PIPE_V>();
|
||||
SetDeqScale((half)1.000000e+00f);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, h);
|
||||
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xOutTensor, halfLocalTemp, RoundMode::CAST_TRUNC, h);
|
||||
|
||||
floatLocalTemp = xOutTensor.template ReinterpretCast<float>();
|
||||
floatLocalTemp.SetValue(hUBAlignSize / sizeof(float), float(1.0) / dynamicScale); // int8->float32
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::FillTriple(
|
||||
LocalTensor<ExpandXOutType> &xOutTensor, uint32_t tokenIndex, uint32_t k)
|
||||
{
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
LocalTensor<int32_t> xOutTint32 = xOutTensor.template ReinterpretCast<int32_t>();
|
||||
xOutTint32(expandIdxStartIdx) = epRankId;
|
||||
xOutTint32(expandIdxStartIdx + 1) = tokenIndex;
|
||||
xOutTint32(expandIdxStartIdx + 2) = k;
|
||||
SyncFunc<AscendC::HardEvent::S_MTE3>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::InputToShare()
|
||||
{
|
||||
DataCopyExtParams sendOffsetParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> sendOffsetCopyPadParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(sendOffsetTensor, sendOffsetGT, sendOffsetParams, sendOffsetCopyPadParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
uint32_t startTokenId, endTokenId, sendTokenNum, remainTokenNum;
|
||||
sendTokenNum = expertIdsCnt / blockNum;
|
||||
remainTokenNum = expertIdsCnt % blockNum;
|
||||
startTokenId = sendTokenNum * blockIdx;
|
||||
if (blockIdx < remainTokenNum) {
|
||||
sendTokenNum += 1;
|
||||
startTokenId += blockIdx;
|
||||
} else {
|
||||
startTokenId += remainTokenNum;
|
||||
}
|
||||
endTokenId = startTokenId + sendTokenNum;
|
||||
|
||||
if (startTokenId >= expertIdsCnt) {
|
||||
return;
|
||||
}
|
||||
tpipe_->InitBuffer(expertIdsBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48
|
||||
tpipe_->InitBuffer(sendTokenIdxBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48
|
||||
expertIdsTensor = expertIdsBuf.Get<int32_t>();
|
||||
sendTokenIdxTensor = sendTokenIdxBuf.Get<int32_t>();
|
||||
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams sendTokenIdxParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<XType> tokenCopyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(expertIdsTensor, expertIdsGT[startTokenId], expertIdsCntParams, copyPadExtParams);
|
||||
DataCopyPad(sendTokenIdxTensor, sendTokenIdxGT[startTokenId], sendTokenIdxParams, copyPadExtParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_S>();
|
||||
|
||||
DataCopyExtParams xCopyParams = {1U, static_cast<uint32_t>(h * sizeof(XType)), 0U, 0U, 0U};
|
||||
for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
|
||||
uint32_t dstExpertId = expertIdsTensor(tokenIndex - startTokenId);
|
||||
int32_t curExpertCnt = sendTokenIdxTensor(tokenIndex - startTokenId);
|
||||
int32_t dstExpertOffset = sendOffsetTensor(dstExpertId);
|
||||
GM_ADDR rankGM =
|
||||
(__gm__ uint8_t *)(shareGM + hOutGMAlignSize * (dstExpertOffset + curExpertCnt));
|
||||
dstGT.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM);
|
||||
|
||||
if constexpr (DynamicQuant) {
|
||||
xInTensor = xInQueue.AllocTensor<XType>();
|
||||
DataCopyPad(xInTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams);
|
||||
xInQueue.EnQue(xInTensor);
|
||||
xInTensor = xInQueue.DeQue<XType>();
|
||||
xOutTensor = xOutQueue.AllocTensor<ExpandXOutType>();
|
||||
QuantProcess();
|
||||
xOutQueue.EnQue(xOutTensor);
|
||||
xOutTensor = xOutQueue.DeQue<ExpandXOutType>();
|
||||
FillTriple(xOutTensor, tokenIndex / topK, tokenIndex % topK);
|
||||
DataCopyPad(dstGT, xOutTensor, hCommuCopyOutParams);
|
||||
xOutQueue.FreeTensor(xOutTensor);
|
||||
} else {
|
||||
xTmpTensor = xQueue.AllocTensor<ExpandXOutType>();
|
||||
DataCopyPad(xTmpTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams);
|
||||
xQueue.EnQue(xTmpTensor);
|
||||
xTmpTensor = xQueue.DeQue<ExpandXOutType>();
|
||||
FillTriple(xTmpTensor, tokenIndex / topK, tokenIndex % topK);
|
||||
DataCopyPad(dstGT, xTmpTensor, hCommuCopyOutParams);
|
||||
xQueue.FreeTensor<ExpandXOutType>(xTmpTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::SetStatus()
|
||||
{
|
||||
uint32_t startExpId, endExpId, expNumPerCore;
|
||||
expNumPerCore = statusNumPerCore;
|
||||
startExpId = startStatusId;
|
||||
endExpId = endStatusId;
|
||||
if (startExpId > moeExpertNum) {
|
||||
SyncAll<true>();
|
||||
return;
|
||||
}
|
||||
uint32_t statusCntAlign = Ceil(expNumPerCore, 8) * 8;
|
||||
tpipe_->InitBuffer(statusBuf, statusCntAlign * UB_ALIGN); // moeNum / 48 * 32
|
||||
statusTensor = statusBuf.Get<int32_t>();
|
||||
Duplicate<int32_t>(statusTensor, 0, expNumPerCore * 8);
|
||||
uint64_t mask[2] = {0x101010101010101, 0};
|
||||
PipeBarrier<PIPE_V>();
|
||||
Duplicate<int32_t>(statusTensor, 0x3F800000, mask, statusCntAlign / 8, 1, 8);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
SyncAll<true>();
|
||||
for (uint32_t i = startExpId; i < endExpId; ++i) {
|
||||
uint32_t targetRankId = i / moeExpertNumPerRank;
|
||||
uint32_t offset = stateOffset * (epRankId + i % moeExpertNumPerRank * epRankSize);
|
||||
GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, targetRankId) + offset);
|
||||
dstStatusGT.SetGlobalBuffer((__gm__ int32_t *)rankGM);
|
||||
DataCopy<int32_t>(dstStatusGT, statusTensor[(i - startExpId) * 8], 8UL);
|
||||
}
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::WaitStatus()
|
||||
{
|
||||
tpipe_->Reset();
|
||||
uint32_t waitStatusBufSize = (((statusNumPerCore * UB_ALIGN) > 256) ? (statusNumPerCore * UB_ALIGN) : 256);
|
||||
tpipe_->InitBuffer(waitStatusBuf, waitStatusBufSize); // moeNum /48 * 32B = 43 * 32B
|
||||
tpipe_->InitBuffer(gatherMaskOutBuf, moeExpertNum * sizeof(float)); // moeNum * 4B
|
||||
tpipe_->InitBuffer(scalarBuf, UB_ALIGN * 3); // 96B
|
||||
tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 28K
|
||||
tpipe_->InitBuffer(recvOffsetBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B
|
||||
tpipe_->InitBuffer(recvCountBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B
|
||||
|
||||
recvOffsetTensor = recvOffsetBuf.Get<int32_t>();
|
||||
recvCountTensor = recvCountBuf.Get<int32_t>();
|
||||
DataCopyExtParams recvOffsetParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams recvCountParams = {1U, static_cast<uint32_t>(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||
DataCopyPadExtParams<int32_t> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyPad(recvOffsetTensor, recvOffsetGT, recvOffsetParams, copyPadExtParams);
|
||||
DataCopyPad(recvCountTensor, recvCountGT, recvCountParams, copyPadExtParams);
|
||||
|
||||
if (startStatusId >= moeExpertNum) {
|
||||
SyncAll<true>();
|
||||
return;
|
||||
}
|
||||
|
||||
LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf.Get<float>();
|
||||
LocalTensor<float> statusSumOutTensor = scalarBuf.GetWithOffset<float>(UB_ALIGN / sizeof(float), UB_ALIGN);
|
||||
LocalTensor<float> statusFp32Tensor = waitStatusBuf.Get<float>();
|
||||
GlobalTensor<float> windowInstatusFp32Tensor;
|
||||
windowInstatusFp32Tensor.SetGlobalBuffer((__gm__ float *)(GetWindStateAddrByRankId(COMM_EP_IDX, epRankId)));
|
||||
uint32_t mask = 1;
|
||||
float compareTarget = static_cast<float>(1.0) * statusNumPerCore;
|
||||
float sumOfFlag = static_cast<float>(-1.0);
|
||||
DataCopyParams intriParams{static_cast<uint16_t>(statusNumPerCore), 1, 0, 0};
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
while (sumOfFlag != compareTarget) {
|
||||
DataCopy(statusFp32Tensor, windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)], intriParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE2_V>();
|
||||
ReduceSum(statusSumOutTensor, statusFp32Tensor, gatherMaskOutTensor, mask, statusNumPerCore, 1);
|
||||
SyncFunc<AscendC::HardEvent::V_S>();
|
||||
sumOfFlag = statusSumOutTensor.GetValue(0);
|
||||
}
|
||||
|
||||
// Clear state
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
DataCopyParams intriOutParams{static_cast<uint16_t>(statusNumPerCore), 1, 0, 0};
|
||||
uint64_t duplicateMask[2] = {0x101010101010101, 0};
|
||||
LocalTensor<int32_t> cleanStateTensor = waitStatusBuf.Get<int32_t>();
|
||||
SyncFunc<AscendC::HardEvent::S_V>();
|
||||
Duplicate<int32_t>(cleanStateTensor, 0, duplicateMask, Ceil(statusNumPerCore, 8), 1, 8);
|
||||
SyncFunc<AscendC::HardEvent::V_MTE3>();
|
||||
DataCopy(windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)],
|
||||
cleanStateTensor.ReinterpretCast<float>(),
|
||||
intriOutParams);
|
||||
SyncFunc<AscendC::HardEvent::MTE3_S>();
|
||||
SyncAll<true>();
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::ShareToOutput()
|
||||
{
|
||||
if (startStatusId >= moeExpertNum) {
|
||||
return;
|
||||
}
|
||||
uint32_t fromRank, count, preCount, recvOffset, targetOffset;
|
||||
DataCopyPadExtParams<ExpandXOutType> copyPadExtParams{false, 0U, 0U, 0U};
|
||||
DataCopyExtParams dataCopyExandIdxParams{1U, sizeof(int32_t) * EXPAND_IDX_INFO, 0U, 0U, 0U};
|
||||
DataCopyExtParams dataCopyOutParams{1U, static_cast<uint32_t>(statusNumPerCore * sizeof(int32_t)), 0U, 0U, 0U};
|
||||
DataCopyExtParams expandXCopyParams = {1U, static_cast<uint32_t>(h * sizeof(ExpandXOutType)), 0U, 0U, 0U};
|
||||
LocalTensor<int32_t> xTmpTensorInt;
|
||||
AscendC::TQueSync<PIPE_MTE2, PIPE_S> recvCountLocalSync;
|
||||
recvCountLocalSync.SetFlag(0);
|
||||
recvCountLocalSync.WaitFlag(0);
|
||||
for (uint32_t i = startStatusId; i < endStatusId; ++i) {
|
||||
preCount = 0;
|
||||
if (likely(i != 0)) {
|
||||
preCount = recvCountTensor(i - 1);
|
||||
}
|
||||
fromRank = i % epRankSize;
|
||||
count = recvCountTensor(i) - preCount;
|
||||
recvOffset = recvOffsetTensor(i);
|
||||
targetOffset = preCount;
|
||||
GM_ADDR recvStart =
|
||||
(__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, fromRank)) + recvOffset * hOutGMAlignSize;
|
||||
GlobalTensor<ExpandXOutType> srcTokenGT, dstTokenGT;
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
srcTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(recvStart + j * hOutGMAlignSize));
|
||||
xTmpTensor = xQueue.AllocTensor<ExpandXOutType>();
|
||||
DataCopyPad(xTmpTensor, srcTokenGT, hCommuCopyOutParams, copyPadExtParams);
|
||||
xQueue.EnQue(xTmpTensor);
|
||||
xTmpTensor = xQueue.DeQue<ExpandXOutType>();
|
||||
xTmpTensorInt = xTmpTensor.template ReinterpretCast<int32_t>();
|
||||
DataCopyPad(expandIdxOutGT[(targetOffset + j) * EXPAND_IDX_INFO],
|
||||
xTmpTensorInt[expandIdxStartIdx],
|
||||
dataCopyExandIdxParams);
|
||||
if constexpr (DynamicQuant) {
|
||||
DataCopyExtParams floatDataCopyParams = {1U, sizeof(float), 0U, 0U, 0U};
|
||||
LocalTensor<float> xOutFp32Tensor = xTmpTensor.template ReinterpretCast<float>();
|
||||
DataCopyPad(dynamicScalesOutGT[targetOffset + j],
|
||||
xOutFp32Tensor[hUBAlignSize / sizeof(float)],
|
||||
floatDataCopyParams);
|
||||
}
|
||||
dstTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM) + (targetOffset + j) * h, h);
|
||||
DataCopyPad(dstTokenGT, xTmpTensor, expandXCopyParams);
|
||||
xQueue.FreeTensor(xTmpTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <CamTypeClass>
|
||||
__aicore__ inline void MoeDispatchNormal<CamTypeFunc>::Process()
|
||||
{
|
||||
if ASCEND_IS_AIV {
|
||||
InputToShare();
|
||||
SetStatus();
|
||||
WaitStatus();
|
||||
ShareToOutput();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MoeDispatchNormalImpl
|
||||
#endif
|
||||
@@ -0,0 +1,30 @@
|
||||
#ifndef MOE_DISPATCH_NORMAL_TILING_H
|
||||
#define MOE_DISPATCH_NORMAL_TILING_H
|
||||
|
||||
struct MoeDispatchNormalInfo {
|
||||
uint32_t epWorldSize; // epWorldSize
|
||||
uint32_t tpWorldSize; // tpWorldSize
|
||||
uint32_t epRankId; // epRankId
|
||||
uint32_t tpRankId; // tpRankId
|
||||
uint32_t moeExpertNum; // moe expert number
|
||||
uint32_t quantMode; // quant mode
|
||||
uint32_t globalBs; // globalBs = BS * worldSize
|
||||
uint32_t bs; // bs
|
||||
uint32_t k; // k
|
||||
uint32_t h; // h
|
||||
uint32_t aivNum; // aivNum
|
||||
bool isQuant; // whether quant or not
|
||||
bool reserved2; // reserved
|
||||
bool reserved3; // reserved
|
||||
uint64_t totalUbSize; // epWorldSize
|
||||
uint64_t totalWinSize;
|
||||
};
|
||||
|
||||
struct MoeDispatchNormalTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
Mc2CcTiling mc2CcTiling2;
|
||||
MoeDispatchNormalInfo moeDispatchNormalInfo;
|
||||
};
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user