[Kernel] add custom moe ops for prefill (#4194)
### What this PR does / why we need it?
1.Add the implementation of normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch,and DispatchLayout.
- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Used to calculate information related to the device
memory layout for the dispatch stage.
2.Provide PyTorch interfaces for normal operators—get_dispatch_layout,
dispatch_prefill, and combine_prefill—to be used for MoE communication
during the prefill stage in vLLM.
- get_dispatch_layout: Calculates information related to the device
memory layout for the dispatch operator, and is called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has already been validated using the local Qwen model.
Test cases will be added after support for multi-NPU use cases in the CI
pipeline is finalized.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
This commit is contained in:
49
csrc/notify_dispatch/op_host/CMakeLists.txt
Normal file
49
csrc/notify_dispatch/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME NotifyDispatch
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
notify_dispatch.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_notify_dispatch.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
notify_dispatch_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_notify_dispatch.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
84
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp
Normal file
84
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn_notify_dispatch.h"
|
||||
|
||||
extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
extern aclnnStatus aclnnInnerNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerNotifyDispatchGetWorkspaceSize(
|
||||
sendData,
|
||||
tokenPerExpertData,
|
||||
sendCount,
|
||||
numTokens,
|
||||
commGroup,
|
||||
rankSize,
|
||||
rankId,
|
||||
localRankSize,
|
||||
localRankId,
|
||||
sendDataOffset,
|
||||
recvData,
|
||||
workspaceSize,
|
||||
executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerNotifyDispatch(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
61
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h
Normal file
61
csrc/notify_dispatch/op_host/aclnn_notify_dispatch.h
Normal file
@@ -0,0 +1,61 @@
|
||||
|
||||
#ifndef ACLNN_NOTIFY_DISPATCH_H_
|
||||
#define ACLNN_NOTIFY_DISPATCH_H_
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* funtion: aclnnNotifyDispatchGetWorkspaceSize
|
||||
* parameters :
|
||||
* sendData : required
|
||||
* tokenPerExpertData : required
|
||||
* sendCount : required
|
||||
* numTokens : required
|
||||
* commGroup : required
|
||||
* rankSize : required
|
||||
* rankId : required
|
||||
* localRankSize : required
|
||||
* localRankId : required
|
||||
* sendDataOffset : required
|
||||
* recvData : required
|
||||
* workspaceSize : size of workspace(output).
|
||||
* executor : executor context(output).
|
||||
*/
|
||||
__attribute__((visibility("default")))
|
||||
aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(
|
||||
const aclTensor *sendData,
|
||||
const aclTensor *tokenPerExpertData,
|
||||
int64_t sendCount,
|
||||
int64_t numTokens,
|
||||
char *commGroup,
|
||||
int64_t rankSize,
|
||||
int64_t rankId,
|
||||
int64_t localRankSize,
|
||||
int64_t localRankId,
|
||||
const aclTensor *sendDataOffset,
|
||||
const aclTensor *recvData,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
/* funtion: aclnnNotifyDispatch
|
||||
* parameters :
|
||||
* workspace : workspace memory addr(input).
|
||||
* workspaceSize : size of workspace(input).
|
||||
* executor : executor context(input).
|
||||
* stream : acl stream.
|
||||
*/
|
||||
__attribute__((visibility("default")))
|
||||
aclnnStatus aclnnNotifyDispatch(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
60
csrc/notify_dispatch/op_host/notify_dispatch.cpp
Normal file
60
csrc/notify_dispatch/op_host/notify_dispatch.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class NotifyDispatch : public OpDef {
|
||||
public:
|
||||
explicit NotifyDispatch(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("sendData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("tokenPerExpertData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("sendDataOffset")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("recvData")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("sendCount").Int();
|
||||
this->Attr("num_tokens").Int();
|
||||
this->Attr("comm_group").String();
|
||||
this->Attr("rank_size").Int();
|
||||
this->Attr("rank_id").Int();
|
||||
this->Attr("local_rank_size").Int();
|
||||
this->Attr("local_rank_id").Int();
|
||||
|
||||
OpAICoreConfig aicore_config_base;
|
||||
aicore_config_base.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
|
||||
OpAICoreConfig aicore_config_A2 = aicore_config_base;
|
||||
aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false");
|
||||
|
||||
OpAICoreConfig aicore_config = aicore_config_base;
|
||||
aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true");
|
||||
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->AICore().AddConfig("ascend910b", aicore_config_A2);
|
||||
this->MC2().HcclGroup("comm_group");
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(NotifyDispatch);
|
||||
} // namespace ops
|
||||
306
csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp
Normal file
306
csrc/notify_dispatch/op_host/notify_dispatch_tiling.cpp
Normal file
@@ -0,0 +1,306 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/notify_dispatch_tiling.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
#include "experiment/platform/platform/platform_infos_def.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace {
|
||||
class Mc2TilingUtils {
|
||||
public:
|
||||
#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
|
||||
static uint64_t GetMaxWindowSize()
|
||||
{
|
||||
uint16_t defaultWindowSize = 200;
|
||||
if (getenv(HCCL_BUFFSIZE) == nullptr) {
|
||||
OPS_LOG_D("", "Env HCCL_BUFFSIZE don't set");
|
||||
} else {
|
||||
try {
|
||||
std::string envStr(getenv(HCCL_BUFFSIZE));
|
||||
defaultWindowSize = std::stoi(envStr);
|
||||
} catch (const std::invalid_argument &ia) {
|
||||
OPS_LOG_E("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
|
||||
} catch (const std::out_of_range &oor) {
|
||||
OPS_LOG_E("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
|
||||
}
|
||||
}
|
||||
const uint64_t maxWindowSize = static_cast<uint64_t>(defaultWindowSize) * 1024UL * 1024UL;
|
||||
OPS_LOG_I("", "Get maxWindowSize is %lu", maxWindowSize);
|
||||
return maxWindowSize;
|
||||
}
|
||||
};
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll
|
||||
|
||||
constexpr uint32_t INPUT_SEND_DATA_INDEX = 0;
|
||||
constexpr uint32_t INPUT_TOKEN_PER_EXPERT_INDEX = 1;
|
||||
|
||||
constexpr uint32_t OUTPUT_SEND_DATA_OFFSET_INDEX = 0;
|
||||
constexpr uint32_t OUTPUT_RECV_DATA_INDEX = 1;
|
||||
|
||||
constexpr uint32_t ATTR_SEND_COUNT_INDEX = 0;
|
||||
constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 1;
|
||||
constexpr uint32_t ATTR_COMM_GROUP_INDEX = 2;
|
||||
constexpr uint32_t ATTR_RANK_SIZE_INDEX = 3;
|
||||
constexpr uint32_t ATTR_RANK_ID_INDEX = 4;
|
||||
constexpr uint32_t ATTR_LOCAL_RANK_SIZE_INDEX = 5;
|
||||
constexpr uint32_t ATTR_LOCAL_RANK_ID_INDEX = 6;
|
||||
|
||||
const size_t MAX_GROUP_NAME_LENGTH = 128UL;
|
||||
const int64_t MAX_COMM_WORLD_SIZE = 384;
|
||||
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
|
||||
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;
|
||||
constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes
|
||||
constexpr uint64_t MB_SIZE = 1024UL * 1024UL;
|
||||
|
||||
constexpr static int TILING_KEY_FLOAT16 = 20;
|
||||
constexpr static int TILING_KEY_BFLOAT16 = 21;
|
||||
constexpr static int TILING_KEY_FLOAT = 22;
|
||||
constexpr static int TILING_KEY_INT = 23;
|
||||
constexpr static int TILING_KEY_A2_TYPE = 100;
|
||||
|
||||
constexpr static int ALL_TO_ALL_CORE_NUM = 32;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static void PrintTilingDataInfo(const char *nodeName, NotifyDispatchTilingData &tilingData)
|
||||
{
|
||||
OPS_LOG_D(nodeName, "rankSize is %u.", tilingData.notifyDispatchInfo.rankSize);
|
||||
OPS_LOG_D(nodeName, "rankId is %u.", tilingData.notifyDispatchInfo.rankId);
|
||||
OPS_LOG_D(nodeName, "localRankSize is %u.", tilingData.notifyDispatchInfo.localRankSize);
|
||||
OPS_LOG_D(nodeName, "localRankId is %u.", tilingData.notifyDispatchInfo.localRankId);
|
||||
OPS_LOG_D(nodeName, "sendCount is %u.", tilingData.notifyDispatchInfo.sendCount);
|
||||
OPS_LOG_D(nodeName, "numTokens is %u.", tilingData.notifyDispatchInfo.numTokens);
|
||||
OPS_LOG_D(nodeName, "aivNum is %u.", tilingData.notifyDispatchInfo.aivNum);
|
||||
OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.notifyDispatchInfo.totalUbSize);
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
|
||||
NotifyDispatchTilingData &tilingData, std::string &commGroup)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto sendCountPtr = attrs->GetAttrPointer<int64_t>(ATTR_SEND_COUNT_INDEX);
|
||||
auto numTokenPtr = attrs->GetAttrPointer<int64_t>(ATTR_NUM_TOKENS_INDEX);
|
||||
auto commGroupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_COMM_GROUP_INDEX));
|
||||
auto rankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_SIZE_INDEX);
|
||||
auto rankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_ID_INDEX);
|
||||
auto localRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_LOCAL_RANK_SIZE_INDEX);
|
||||
auto localRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_LOCAL_RANK_ID_INDEX);
|
||||
|
||||
OPS_CHECK((commGroupPtr == nullptr) || (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == 0) ||
|
||||
(strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH),
|
||||
OPS_LOG_E(nodeName, "commGroupPtr is null."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(sendCountPtr == nullptr, OPS_LOG_E(nodeName, "sendCountPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(numTokenPtr == nullptr, OPS_LOG_E(nodeName, "numTokenPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(rankSizePtr == nullptr, OPS_LOG_E(nodeName, "rankSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(rankIdPtr == nullptr, OPS_LOG_E(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(
|
||||
localRankSizePtr == nullptr, OPS_LOG_E(nodeName, "localRankSizePtr is null."), return ge::GRAPH_FAILED);
|
||||
OPS_CHECK(localRankIdPtr == nullptr, OPS_LOG_E(nodeName, "localRankIdPtr is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_COMM_WORLD_SIZE),
|
||||
OPS_LOG_E(nodeName,
|
||||
"rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
|
||||
MAX_COMM_WORLD_SIZE,
|
||||
*rankSizePtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*rankIdPtr < 0) || (*rankIdPtr >= *rankSizePtr),
|
||||
OPS_LOG_E(nodeName, "rankId is invalid, only support [0, %ld), but got rankId=%ld.", *rankSizePtr, *rankIdPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*sendCountPtr <= 0),
|
||||
OPS_LOG_E(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%ld.", *sendCountPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_CHECK((*numTokenPtr <= 0),
|
||||
OPS_LOG_E(nodeName, "numTokenPtr is invalid, only support > 0, but got numTokenPtr=%ld.", *numTokenPtr),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
commGroup = std::string(commGroupPtr);
|
||||
tilingData.notifyDispatchInfo.rankSize = static_cast<uint32_t>(*rankSizePtr);
|
||||
tilingData.notifyDispatchInfo.rankId = static_cast<uint32_t>(*rankIdPtr);
|
||||
tilingData.notifyDispatchInfo.localRankSize = static_cast<uint32_t>(*localRankSizePtr);
|
||||
tilingData.notifyDispatchInfo.localRankId = static_cast<uint32_t>(*localRankIdPtr);
|
||||
tilingData.notifyDispatchInfo.sendCount = static_cast<uint32_t>(*sendCountPtr);
|
||||
tilingData.notifyDispatchInfo.numTokens = static_cast<uint32_t>(*numTokenPtr);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void SetHcommCfg(const gert::TilingContext *context,
|
||||
NotifyDispatchTilingData *tiling, const std::string commGroup)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_D(nodeName, "NotifyDispatch commGroup = %s", commGroup.c_str());
|
||||
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(commGroup, opType1, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1);
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static bool CheckTensorDataType(
|
||||
gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto sendData = context->GetInputDesc(INPUT_SEND_DATA_INDEX);
|
||||
OPS_CHECK(sendData == nullptr, OPS_LOG_E(nodeName, "sendData is null."), return false);
|
||||
OPS_CHECK((sendData->GetDataType() != ge::DT_BF16) && (sendData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(sendData->GetDataType() != ge::DT_FLOAT) && (sendData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"sendData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(sendData->GetDataType())),
|
||||
return false);
|
||||
uint64_t dataSize;
|
||||
if ((sendData->GetDataType() == ge::DT_BF16) || (sendData->GetDataType() == ge::DT_FLOAT16)) {
|
||||
dataSize = 2;
|
||||
} else {
|
||||
dataSize = 4;
|
||||
}
|
||||
auto tokenPerExpertData = context->GetInputDesc(INPUT_TOKEN_PER_EXPERT_INDEX);
|
||||
OPS_CHECK(tokenPerExpertData == nullptr, OPS_LOG_E(nodeName, "tokenPerExpertData is null."), return false);
|
||||
OPS_CHECK((tokenPerExpertData->GetDataType() != ge::DT_BF16) && (tokenPerExpertData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(tokenPerExpertData->GetDataType() != ge::DT_FLOAT) && (tokenPerExpertData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"tokenPerExpertData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(tokenPerExpertData->GetDataType())),
|
||||
return false);
|
||||
|
||||
auto sendDataOffset = context->GetInputDesc(OUTPUT_SEND_DATA_OFFSET_INDEX);
|
||||
OPS_CHECK(sendDataOffset == nullptr, OPS_LOG_E(nodeName, "sendDataOffset is null."), return false);
|
||||
OPS_CHECK((sendDataOffset->GetDataType() != ge::DT_BF16) && (sendDataOffset->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(sendDataOffset->GetDataType() != ge::DT_FLOAT) && (sendDataOffset->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"sendDataOffset datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(sendDataOffset->GetDataType())),
|
||||
return false);
|
||||
|
||||
auto recvData = context->GetInputDesc(OUTPUT_RECV_DATA_INDEX);
|
||||
OPS_CHECK(recvData == nullptr, OPS_LOG_E(nodeName, "recvData is null."), return false);
|
||||
OPS_CHECK((recvData->GetDataType() != ge::DT_BF16) && (recvData->GetDataType() != ge::DT_FLOAT16) &&
|
||||
(recvData->GetDataType() != ge::DT_FLOAT) && (recvData->GetDataType() != ge::DT_INT32),
|
||||
OPS_LOG_E(nodeName,
|
||||
"recvData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.",
|
||||
static_cast<ge::DataType>(recvData->GetDataType())),
|
||||
return false);
|
||||
|
||||
// Verify the size of the win area
|
||||
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
|
||||
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
|
||||
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount;
|
||||
if (actualSize > maxWindowSize) {
|
||||
OPS_LOG_E(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingCheckTensor(
|
||||
gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
OPS_CHECK(!CheckTensorDataType(context, nodeName),
|
||||
OPS_LOG_E(nodeName, "params dataType is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus NotifyDispatchTilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
|
||||
OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
std::string commGroup = "";
|
||||
OPS_LOG_I(nodeName, "Enter NotifyDispatch tiling check func.");
|
||||
|
||||
OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, commGroup) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling check param failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, commGroup);
|
||||
|
||||
int tilingKey = TILING_KEY_INT;
|
||||
auto sendDtype = context->GetInputDesc(0)->GetDataType();
|
||||
if (sendDtype == ge::DT_FLOAT16) {
|
||||
tilingKey = TILING_KEY_FLOAT16;
|
||||
} else if (sendDtype == ge::DT_BF16) {
|
||||
tilingKey = TILING_KEY_BFLOAT16;
|
||||
} else if (sendDtype == ge::DT_FLOAT) {
|
||||
tilingKey = TILING_KEY_FLOAT;
|
||||
}
|
||||
|
||||
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
|
||||
fe::PlatFormInfos &platformInfo = *platformInfoPtr;
|
||||
|
||||
std::string socVersion;
|
||||
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);
|
||||
|
||||
if (socVersion == "Ascend910B") {
|
||||
tilingKey = tilingKey + TILING_KEY_A2_TYPE;
|
||||
}
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t blockDim;
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0UL;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
|
||||
blockDim = aivNum;
|
||||
context->SetBlockDim(blockDim);
|
||||
tilingData->notifyDispatchInfo.totalUbSize = ubSize;
|
||||
tilingData->notifyDispatchInfo.aivNum = aivNum;
|
||||
OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
|
||||
PrintTilingDataInfo(nodeName, *tilingData);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus NotifyDispatchTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
ge::graphStatus ret = NotifyDispatchTilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct NotifyDispatchCompileInfo {};
|
||||
ge::graphStatus TilingParseForNotifyDispatch(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(NotifyDispatch)
|
||||
.Tiling(NotifyDispatchTilingFunc)
|
||||
.TilingParse<NotifyDispatchCompileInfo>(TilingParseForNotifyDispatch);
|
||||
} // namespace optiling
|
||||
57
csrc/notify_dispatch/op_kernel/notify_dispatch.cpp
Normal file
57
csrc/notify_dispatch/op_kernel/notify_dispatch.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
#include "kernel_operator.h"
|
||||
#include "notify_dispatch.h"
|
||||
#include "notify_dispatch_tiling.h"
|
||||
|
||||
#define TILING_KEY_FLOAT16 20
|
||||
#define TILING_KEY_BFLOAT16 21
|
||||
#define TILING_KEY_FLOAT 22
|
||||
#define TILING_KEY_INT 23
|
||||
|
||||
#define KERNEL_USE_WORKSPACE (1 * 1024 * 1024)
|
||||
|
||||
extern "C" __global__ __aicore__ void notify_dispatch(
|
||||
GM_ADDR sendData, GM_ADDR tokenPerExpertData, GM_ADDR sendDataOffset, GM_ADDR recvData, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(NotifyDispatchTilingData);
|
||||
GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tiling);
|
||||
|
||||
// hcomm will set magic later in init
|
||||
uint32_t magic = 1;
|
||||
GM_ADDR commArgs = nullptr;
|
||||
|
||||
int localRank = tilingData.notifyDispatchInfo.localRankId;
|
||||
int localRankSize = tilingData.notifyDispatchInfo.localRankSize;
|
||||
int rank = tilingData.notifyDispatchInfo.rankId;
|
||||
int rankSize = tilingData.notifyDispatchInfo.rankSize;
|
||||
int64_t len = tilingData.notifyDispatchInfo.sendCount;
|
||||
int64_t numTokens = tilingData.notifyDispatchInfo.numTokens;
|
||||
|
||||
GM_ADDR sendDataInput = sendData;
|
||||
GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
|
||||
GM_ADDR sendDataOffsetOutput = sendDataOffset;
|
||||
GM_ADDR recvDataOutput = recvData;
|
||||
|
||||
// fill in unused args
|
||||
uint32_t extraFlag = 0;
|
||||
GM_ADDR scale = nullptr;
|
||||
int root = 0;
|
||||
int op = 0;
|
||||
int cycleCount = 0;
|
||||
int64_t scaleCount = 0;
|
||||
GM_ADDR offset = nullptr;
|
||||
int blockNum = GetBlockNum();
|
||||
|
||||
if (TILING_KEY_IS(TILING_KEY_FLOAT16)) {
|
||||
NotifyDispatch<float16_t> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
} else if (TILING_KEY_IS(TILING_KEY_FLOAT)) {
|
||||
NotifyDispatch<float> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
} else if (TILING_KEY_IS(TILING_KEY_INT)) {
|
||||
NotifyDispatch<int> opKernel(rank, rankSize, extraFlag);
|
||||
opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
opKernel.Process();
|
||||
}
|
||||
}
|
||||
495
csrc/notify_dispatch/op_kernel/notify_dispatch.h
Normal file
495
csrc/notify_dispatch/op_kernel/notify_dispatch.h
Normal file
@@ -0,0 +1,495 @@
|
||||
#ifndef NOTIFY_DISPATCH_H
|
||||
#define NOTIFY_DISPATCH_H
|
||||
|
||||
#include <climits>
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#include "../common/comm_args.h"
|
||||
#include "../common/data_copy.h"
|
||||
#include "../common/sync_collectives.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace Moe;
|
||||
|
||||
#define KERNELS_ARGS_FUN_ALL2ALL() \
|
||||
GM_ADDR sendDataInput, GM_ADDR tokenPerExpertDataInput, GM_ADDR sendDataOffsetOutput, GM_ADDR recvDataOutput, \
|
||||
int64_t len, int64_t numTokens, int op, int root, int cycleCount, GM_ADDR scale, int64_t scaleCount, \
|
||||
GM_ADDR offset, int localRank, int localRankSize, GM_ADDR commArgs, int magic
|
||||
|
||||
#define KERNELS_ARGS_CALL_ALL2ALL() \
|
||||
sendDataInput, tokenPerExpertDataInput, sendDataOffsetOutput, recvDataOutput, len, numTokens, op, root, \
|
||||
cycleCount, scale, scaleCount, offset, localRank, localRankSize, commArgs, magic
|
||||
|
||||
template <typename T>
|
||||
class NotifyDispatch {
|
||||
constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; // Invalid rank
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_X = 24; // Maximum number of cores provided by the producer stage
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_Y = 16; // Maximum number of cores provided by the consumer stage
|
||||
constexpr static int64_t CORE_NUMS_PER_STAGE_Z = 16; // Maximum number of cores provided by the consumer stage 2
|
||||
constexpr static int64_t SHARE_QUE_DEPTH = 1; // Depth of a single shared queue
|
||||
constexpr static int64_t RANK_NUM_PER_NODE = 16;
|
||||
constexpr static int64_t SIO_NUM = 2; // Depth of a single shared queue
|
||||
constexpr static int64_t MAX_CORE_NUM = 48;
|
||||
constexpr static int64_t MAX_RANK_PER_CORE = 8;
|
||||
constexpr static int64_t MULTI_RANK_SIZE = 48;
|
||||
constexpr static int64_t MAX_BUFFER_NUMBER = 10;
|
||||
|
||||
constexpr static int64_t IDLER_CORE = 0; // Idle core
|
||||
constexpr static int64_t PRODUCER_CORE = 1; // Producer group, responsible for writing data to shared memory, input->share, or share->share
|
||||
constexpr static int64_t CONSUMER_CORE = 2; // Consumer group, responsible for reading data from shared memory, share->output
|
||||
constexpr static int64_t CONSUMER_CORE2 = 3;
|
||||
|
||||
public:
|
||||
__aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag)
|
||||
: rank(rank), rankSize(rankSize), extraFlag(extraFlag)
|
||||
{}
|
||||
|
||||
__aicore__ inline void Init(KERNELS_ARGS_FUN_ALL2ALL())
|
||||
{
|
||||
InitSmallFullMesh(KERNELS_ARGS_CALL_ALL2ALL());
|
||||
nodeNum = rankSize / localRankSize;
|
||||
localRankId = rank % localRankSize;
|
||||
localNodeId = rank / localRankSize;
|
||||
perNodeDataNum = GetDataCount(len, nodeNum); // 128K/4 = 32K
|
||||
perRankDataNum = GetDataCount(len, rankSize); // 128K/64 = 2K
|
||||
|
||||
tokenPerExpertDataAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
sendDataOffsetAlignLen = Ceil(numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
sendDataAlignLen = Ceil(numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
|
||||
|
||||
// Initialize core grouping
|
||||
InitCoreGroup();
|
||||
// Initialize data slicing
|
||||
InitDataSlice();
|
||||
|
||||
this->sendDataInput = (__gm__ T *)sendDataInput;
|
||||
this->tokenPerExpertDataInput = (__gm__ int32_t *)tokenPerExpertDataInput;
|
||||
this->sendDataOffsetOutput = (__gm__ T *)sendDataOffsetOutput;
|
||||
this->recvDataOutput = (__gm__ T *)recvDataOutput;
|
||||
sendDataInputGt.SetGlobalBuffer((__gm__ T *)sendDataInput);
|
||||
tokenPerExpertDataInputGt.SetGlobalBuffer((__gm__ int32_t *)tokenPerExpertDataInput);
|
||||
sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput);
|
||||
recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput);
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
if (blockIdx < 1) {
|
||||
AssembleSendData();
|
||||
}
|
||||
SyncAll<true>();
|
||||
if (blockIdx < coreNumPerStageX) {
|
||||
InputToShareSlice();
|
||||
}
|
||||
if (blockIdx < coreNumPerStageY) {
|
||||
ShareToShareSlice();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void InitCoreGroup()
|
||||
{
|
||||
coreNumPerStageY = MAX_CORE_NUM;
|
||||
coreNumPerStageX = MAX_CORE_NUM;
|
||||
rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM;
|
||||
}
|
||||
|
||||
__aicore__ inline void InitDataSlice()
|
||||
{
|
||||
// The producer is responsible for moving the input data of this rank to shared memory, input-->share
|
||||
if (blockIdx < coreNumPerStageX) {
|
||||
ProducerDataSlice();
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProducerDataSlice()
|
||||
{
|
||||
// The ipcQue responsible for the current core
|
||||
writeGt.SetGlobalBuffer((__gm__ T *)(shareAddrs[rank] + IPC_DATA_OFFSET));
|
||||
}
|
||||
|
||||
__aicore__ inline void AssembleSendData()
|
||||
{
|
||||
pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen);
|
||||
pipe.InitBuffer(sendDataBuf, sendDataAlignLen);
|
||||
pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen);
|
||||
|
||||
__ubuf__ int32_t *tokenPerExpertUB = (__ubuf__ int32_t *)get_imm(96);
|
||||
CpGM2UB(tokenPerExpertUB, (__gm__ int32_t *)tokenPerExpertDataInputGt.GetPhyAddr(), tokenPerExpertDataAlignLen);
|
||||
AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
|
||||
|
||||
__ubuf__ T *sendDataOffsetUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen);
|
||||
__ubuf__ T *sendDataUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen + sendDataOffsetAlignLen);
|
||||
|
||||
int prefixSum = 0;
|
||||
for (int i = 0; i < numExperts; ++i) {
|
||||
int numTokensExpert = tokenPerExpertUB[i];
|
||||
sendDataUB[i * sendPerGroup] = numTokensExpert;
|
||||
sendDataUB[i * sendPerGroup + 1] = prefixSum;
|
||||
sendDataUB[i * sendPerGroup + 2] = numTokens;
|
||||
sendDataOffsetUB[i] = prefixSum;
|
||||
|
||||
prefixSum += numTokensExpert;
|
||||
}
|
||||
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
|
||||
CpUB2GM((__gm__ T *)sendDataInputGt.GetPhyAddr(), sendDataUB, sendDataAlignLen);
|
||||
CpUB2GM((__gm__ T *)sendDataOffsetOutputGt.GetPhyAddr(), sendDataOffsetUB, sendDataOffsetAlignLen);
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
// copy input to other rank share
|
||||
__aicore__ inline void InputToShareSlice()
|
||||
{
|
||||
__ubuf__ int64_t *inputUB = (__ubuf__ int64_t *)get_imm(0);
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
readGt = sendDataInputGt[copyOffset * perRankDataNum];
|
||||
CpGM2GMPingPong<T>(
|
||||
copyLen * perRankDataNum * sizeof(T), readGt, writeGt[copyOffset * perRankDataNum], COPYONLY);
|
||||
int64_t v = MergeMagicWithValue(magic, 1);
|
||||
*inputUB = v;
|
||||
AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
CpUB2GM((__gm__ int64_t *)(shareAddrs[i]) + rank * FLAG_UNIT_INT_NUM, inputUB, sizeof(int64_t));
|
||||
}
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value)
|
||||
{
|
||||
// magic as the high part, eventID as the low part, combined into a value for comparison
|
||||
return (static_cast<int64_t>(static_cast<uint32_t>(magic)) << MAGIC_OFFSET) | static_cast<int64_t>(value);
|
||||
}
|
||||
|
||||
__aicore__ inline void ShareToShareSlice()
|
||||
{
|
||||
__ubuf__ T *inputUB = (__ubuf__ T *)get_imm(96);
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
int checkRank[MAX_RANK_PER_CORE];
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
checkRank[i - copyOffset] = i + rank % copyLen;
|
||||
if (checkRank[i - copyOffset] >= copyOffset + copyLen) {
|
||||
checkRank[i - copyOffset] -= copyLen;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < copyLen; i++) {
|
||||
readGt1[i].SetGlobalBuffer((__gm__ T *)(shareAddrs[checkRank[i]] + IPC_DATA_OFFSET));
|
||||
}
|
||||
sync.WaitSyncFlag(magic, 1, copyOffset, rank, copyLen);
|
||||
for (int i = 0; i < copyLen; i++) {
|
||||
CpGM2GMPingPong<T>(perRankDataNum * sizeof(T),
|
||||
readGt1[i][rank * perRankDataNum],
|
||||
recvDataOutputGt[checkRank[i] * perRankDataNum],
|
||||
COPYONLY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum);
|
||||
__aicore__ inline GM_ADDR GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx);
|
||||
__aicore__ inline int32_t GetMagicValue(void);
|
||||
FORCE_INLINE_AICORE void InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL());
|
||||
template <typename F>
|
||||
FORCE_INLINE_AICORE void SetAtomic(int op);
|
||||
FORCE_INLINE_AICORE void UnsetAtomic(int op);
|
||||
template<HardEvent eventType>
|
||||
FORCE_INLINE_AICORE void SetWaitEvent(event_t eventId);
|
||||
template <typename K, typename U = K>
|
||||
FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor<U>& sendDataInputGt,
|
||||
const GlobalTensor<K>& recvDataOutputGT, int op);
|
||||
|
||||
GlobalTensor<T> sendDataInputGt;
|
||||
GlobalTensor<int> tokenPerExpertDataInputGt;
|
||||
GlobalTensor<T> sendDataOffsetOutputGt;
|
||||
GlobalTensor<T> recvDataOutputGt;
|
||||
GlobalTensor<T> readGt;
|
||||
GlobalTensor<T> writeGt;
|
||||
GlobalTensor<T> readGt1[MAX_BUFFER_NUMBER];
|
||||
GlobalTensor<T> ipcGT;
|
||||
GlobalTensor<int64_t> sendCountMatrixGm;
|
||||
__gm__ T *sendDataInput;
|
||||
__gm__ int *tokenPerExpertDataInput;
|
||||
__gm__ T *sendDataOffsetOutput;
|
||||
__gm__ T *recvDataOutput;
|
||||
int64_t isPad = 0;
|
||||
int64_t maxSliceNum;
|
||||
int64_t revLen = 0;
|
||||
int64_t sendLen = 0;
|
||||
int64_t sliceLen;
|
||||
int64_t perNodeDataNum;
|
||||
int64_t perRankDataNum;
|
||||
int64_t curRankDataNum;
|
||||
int64_t sendOffset[MULTI_RANK_SIZE];
|
||||
int64_t revOffset[MULTI_RANK_SIZE];
|
||||
int64_t inputDataLen[MULTI_RANK_SIZE];
|
||||
|
||||
int64_t nodeNum;
|
||||
int64_t localRankId;
|
||||
int64_t localNodeId;
|
||||
int64_t targetNode;
|
||||
int64_t targetLocalRankIds[2];
|
||||
int64_t queLen;
|
||||
int64_t queSize;
|
||||
int64_t coreNumPerStageX; // Number of cores used per stage
|
||||
int64_t coreNumPerStageY; // Number of cores used per stage
|
||||
int64_t coreNumPerStageZ; // Number of cores used per stage
|
||||
int64_t flagNumPerStage; // Number of synchronization flags used per stage
|
||||
int64_t coreNumPerNode; // Number of cores allocated per node
|
||||
int64_t coreNumPerRank; // Number of cores allocated per rank
|
||||
int64_t rankNumPerCore; // Number of ranks responsible per core
|
||||
int64_t coreGroup; // Functional group of the current core
|
||||
int64_t targetRank[MULTI_RANK_SIZE]; // Ranks responsible by the current core
|
||||
int64_t targetRankX;
|
||||
int64_t targetRankY;
|
||||
|
||||
int64_t queElemLen; // Size of each element in the shared memory queue (in terms of T)
|
||||
|
||||
int64_t copyLen; // Length of the current data slice being copied (in terms of T)
|
||||
|
||||
// for coll
|
||||
int rank;
|
||||
int rankSize;
|
||||
int localRank = 0;
|
||||
int localRankSize = 0;
|
||||
int xRankSize = 0;
|
||||
int yRankSize = 0;
|
||||
int xRankIdx = 0;
|
||||
int yRankIdx = 0;
|
||||
uint32_t extraFlag;
|
||||
int numTokens;
|
||||
int sendPerGroup = 3;
|
||||
int root;
|
||||
int64_t len;
|
||||
int64_t numExperts;
|
||||
int64_t magic;
|
||||
int64_t blockIdx; // Index of the current aicore
|
||||
int64_t blockNum; // Total number of aicores for the current rank
|
||||
int32_t numRanks;
|
||||
int64_t timeout;
|
||||
uint16_t *rootRanks;
|
||||
GM_ADDR scale;
|
||||
GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses
|
||||
__gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GlobalTensor<GM_ADDR> peerMemsAddrGm_;
|
||||
GlobalTensor<int64_t> dfx;
|
||||
TPipe pipe;
|
||||
TBuf<QuePosition::VECCALC> tBuf;
|
||||
TBuf<> tokenPerExpertDataBuf;
|
||||
TBuf<> sendDataOffsetBuf;
|
||||
TBuf<> sendDataBuf;
|
||||
|
||||
uint32_t sendDataAlignLen{0};
|
||||
uint32_t tokenPerExpertDataAlignLen{0};
|
||||
uint32_t sendDataOffsetAlignLen{0};
|
||||
|
||||
SyncCollectives sync;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE int64_t NotifyDispatch<T>::GetDataCount(const int64_t dataLen, const int64_t useBlockNum)
|
||||
{
|
||||
return dataLen / useBlockNum;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline GM_ADDR NotifyDispatch<T>::GetWindAddrByRankId(const int32_t rankId, uint8_t ctxIdx)
|
||||
{
|
||||
uint32_t curRankId = rank;
|
||||
#ifdef OPT_RANK_OFFSET
|
||||
#pragma message("use rank offset")
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + rankId * OPT_RANK_OFFSET;
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) +
|
||||
rankId * OPT_RANK_OFFSET;
|
||||
#else
|
||||
if (curRankId == rankId) {
|
||||
return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn);
|
||||
}
|
||||
return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Assign values to winContext_[COMM_EP_IDX] and blockIdx before calling
|
||||
template <typename T>
|
||||
__aicore__ inline int32_t NotifyDispatch<T>::GetMagicValue(void)
|
||||
{
|
||||
int32_t magic = 0;
|
||||
GlobalTensor<int32_t> selfDataStatusTensor;
|
||||
GM_ADDR statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp);
|
||||
selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
|
||||
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
|
||||
selfDataStatusTensor[blockIdx * UB_ALIGN_SIZE]);
|
||||
magic = selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE);
|
||||
if (magic <= 0) {
|
||||
magic = 1;
|
||||
}
|
||||
selfDataStatusTensor(blockIdx * UB_ALIGN_SIZE) = magic + 1;
|
||||
return magic;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL2ALL())
|
||||
{
|
||||
this->root = root;
|
||||
this->len = len;
|
||||
this->numExperts = len / sendPerGroup;
|
||||
this->numTokens = numTokens;
|
||||
this->scale = scale;
|
||||
this->localRank = localRank;
|
||||
this->localRankSize = localRankSize;
|
||||
this->xRankSize = localRankSize;
|
||||
this->yRankSize = rankSize / localRankSize;
|
||||
this->xRankIdx = rank % localRankSize;
|
||||
this->yRankIdx = rank / localRankSize;
|
||||
blockIdx = GetBlockIdx();
|
||||
blockNum = GetBlockNum();
|
||||
uint8_t ctxIdx;
|
||||
|
||||
winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
this->magic = GetMagicValue();
|
||||
ctxIdx = COMM_EP_IDX;
|
||||
|
||||
shareAddrs[rank] = GetWindAddrByRankId(rank, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
|
||||
int64_t rankNumPerCore = (rankSize + MAX_CORE_NUM - 1) / MAX_CORE_NUM;
|
||||
int64_t copyOffset = blockIdx * rankNumPerCore;
|
||||
int64_t copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore;
|
||||
if (copyLen > 0) {
|
||||
for (int i = copyOffset; i < copyOffset + copyLen; ++i) {
|
||||
shareAddrs[i] = GetWindAddrByRankId(i, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
}
|
||||
}
|
||||
|
||||
// When the number of cores is more than the number of ranks, each core is responsible for fetching data from a specified rank
|
||||
int coreNumPerRank = blockNum / rankSize; // Calculate the number of cores assigned to read for each rank, e.g., 48 cores 4 ranks, each rank is assigned 12 cores
|
||||
int maxCore = coreNumPerRank * rankSize; // Calculate the maximum number of cores that can be used for reading, cores exceeding this number will not take action
|
||||
if (blockIdx < maxCore) {
|
||||
int readRank = blockIdx / coreNumPerRank; // Calculate the rank to be read based on the block, 48 cores divided into 4 groups
|
||||
shareAddrs[readRank] = GetWindAddrByRankId(readRank, ctxIdx) +
|
||||
(this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET);
|
||||
}
|
||||
|
||||
pipe.InitBuffer(tBuf, UB_SINGLE_TOTAL_SIZE_MAX);
|
||||
|
||||
sync.Init(rank, rankSize, shareAddrs, tBuf);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy data from GM to GM with ping-pong method.
|
||||
* @tparam dataSizeRemain The remaining size of data to be copied.
|
||||
* @tparam K The type of output data.
|
||||
* @tparam U The type of input data.
|
||||
* @param sendDataInputGt The global tensor of send data.
|
||||
* @param recvDataOutputGT The global tensor of recv data.
|
||||
* @param op The operation to be performed during the copy.
|
||||
* @details This function copies data from global memory to global memory using a ping-pong method.
|
||||
* It first checks if the input and output types are the same. If they are, it uses a single buffer.
|
||||
* If they are not, it divides the buffer according to the size ratio of the types and aligns it to 32 bytes.
|
||||
* Then, it sets the atomic operation, waits for the flags, and performs the copy operation.
|
||||
*/
|
||||
template <typename T>
|
||||
template <typename K, typename U>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor<U>& sendDataInputGt,
|
||||
const GlobalTensor<K>& recvDataOutputGT, int op)
|
||||
{
|
||||
// General case (U = K), input/output are the same, share one UB
|
||||
// Only when conversion is needed (U->K), UB will be divided into two parts according to the ratio of sizeof(U):sizeof(K) and aligned to 32 bytes
|
||||
constexpr int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX;
|
||||
constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(K) + sizeof(U)) / UB_ALIGN_SIZE * UB_ALIGN_SIZE;
|
||||
constexpr int32_t inputUbBlockSize = std::is_same_v<K, U> ? ubBlockSize : ubAlignNum * sizeof(U);
|
||||
constexpr int32_t outputUbBlockSize = std::is_same_v<K, U> ? ubBlockSize : ubAlignNum * sizeof(K);
|
||||
|
||||
__gm__ U *input = const_cast<__gm__ U *>(sendDataInputGt.GetPhyAddr());
|
||||
__gm__ K *output = const_cast<__gm__ K *>(recvDataOutputGT.GetPhyAddr());
|
||||
__ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET), (__ubuf__ U*)(UB_MID_OFFSET)};
|
||||
__ubuf__ K* outputUB[2] = {(__ubuf__ K*)inputUB[0], (__ubuf__ K*)inputUB[1]};
|
||||
if constexpr (!std::is_same_v<K, U>) {
|
||||
outputUB[0] = (__ubuf__ K*)(inputUB[0] + inputUbBlockSize / sizeof(U));
|
||||
outputUB[1] = (__ubuf__ K*)(inputUB[1] + inputUbBlockSize / sizeof(U));
|
||||
}
|
||||
int inputOffsetNum = 0;
|
||||
int outputOffsetNum = 0;
|
||||
if (dataSizeRemain <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
SetAtomic<K>(op);
|
||||
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID0); // MTE2 waits for MTE3
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); // MTE2 waits for MTE3
|
||||
for (int64_t i = 0; dataSizeRemain > 0; i++) {
|
||||
// size and dataSizeRemain both refer to the output size
|
||||
uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain;
|
||||
event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1;
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(eventId);
|
||||
CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(K) * sizeof(U));
|
||||
if constexpr (!std::is_same_v<K, U>) {
|
||||
SetWaitEvent<HardEvent::MTE2_V>(eventId);
|
||||
CastImpl((i & 1) ? outputUB[0] : outputUB[1], (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE,
|
||||
size / sizeof(K));
|
||||
SetWaitEvent<HardEvent::V_MTE3>(eventId);
|
||||
}
|
||||
AscendC::SetFlag<HardEvent::MTE2_MTE3>(eventId);
|
||||
AscendC::WaitFlag<HardEvent::MTE2_MTE3>(eventId);
|
||||
CpUB2GM(output + outputOffsetNum, (i & 1) ? outputUB[0] : outputUB[1], size);
|
||||
AscendC::SetFlag<HardEvent::MTE3_MTE2>(eventId);
|
||||
|
||||
dataSizeRemain -= size;
|
||||
inputOffsetNum += (size / sizeof(K));
|
||||
outputOffsetNum += (size / sizeof(K));
|
||||
}
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID0); // MTE2 waits for MTE3
|
||||
AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); // MTE2 waits for MTE3
|
||||
|
||||
AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID3); // Scalar waits for MTE3
|
||||
AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID3);
|
||||
|
||||
UnsetAtomic(op);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename F>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::SetAtomic(int op)
|
||||
{
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
if (op != -1) {
|
||||
#ifdef __DAV_C220_VEC__
|
||||
SetAtomicOpType<F>(op);
|
||||
#endif
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::UnsetAtomic(int op)
|
||||
{
|
||||
if (op != -1) {
|
||||
AscendC::SetAtomicNone();
|
||||
}
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<HardEvent eventType>
|
||||
FORCE_INLINE_AICORE void NotifyDispatch<T>::SetWaitEvent(event_t eventId)
|
||||
{
|
||||
AscendC::SetFlag<eventType>(eventId);
|
||||
AscendC::WaitFlag<eventType>(eventId);
|
||||
}
|
||||
|
||||
#endif // NOTIFY_DISPATCH_H
|
||||
23
csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h
Normal file
23
csrc/notify_dispatch/op_kernel/notify_dispatch_tiling.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef NOTIFY_DISPATCH_TILING_H
|
||||
#define NOTIFY_DISPATCH_TILING_H
|
||||
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
struct NotifyDispatchInfo {
|
||||
uint32_t rankSize;
|
||||
uint32_t rankId;
|
||||
uint32_t localRankSize;
|
||||
uint32_t localRankId;
|
||||
uint32_t sendCount;
|
||||
uint32_t numTokens;
|
||||
uint32_t aivNum;
|
||||
uint64_t totalUbSize;
|
||||
};
|
||||
|
||||
struct NotifyDispatchTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling1;
|
||||
NotifyDispatchInfo notifyDispatchInfo;
|
||||
};
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user