[Kernel] add custom moe ops for prefill (#4194)
### What this PR does / why we need it?
1. Add the implementation of the normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch, and DispatchLayout.
- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Used to calculate information related to the device
memory layout for the dispatch stage.
2. Provide PyTorch interfaces for the normal operators — get_dispatch_layout,
dispatch_prefill, and combine_prefill — to be used for MoE communication
during the prefill stage in vLLM.
- get_dispatch_layout: Calculates information related to the device
memory layout for the dispatch operator, and is called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has already been validated using the local Qwen model.
Test cases will be added after support for multi-NPU use cases in the CI
pipeline is finalized.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
This commit is contained in:
49
csrc/dispatch_layout/op_host/CMakeLists.txt
Normal file
49
csrc/dispatch_layout/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Host-side build rules for the DispatchLayout custom op.
add_ops_compile_options(
    OP_NAME DispatchLayout
    OPTIONS --cce-auto-sync=off
            -Wno-deprecated-declarations
            -Werror
)

# Op prototype / registration.
target_sources(op_host_aclnnInner PRIVATE
    dispatch_layout.cpp
)

# Public aclnn API wrapper.
target_sources(opapi PRIVATE
    aclnn_dispatch_layout.cpp
)

if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        aclnn_dispatch_layout.cpp
    )

    target_sources(aclnn_ops_infer PRIVATE
        aclnn_dispatch_layout.cpp
    )
endif ()

# Host tiling implementation.
target_sources(optiling PRIVATE
    dispatch_layout_tiling.cpp
)

target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

target_sources(opsproto PRIVATE)

# Renamed from _GMM_Aclnn_header: the old name was a copy-paste leftover from the GMM op.
file(GLOB _DispatchLayout_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_layout.h")

# Install the public aclnn header; OPTIONAL so a missing header does not fail the install step.
install(FILES ${_DispatchLayout_Aclnn_header}
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
|
||||
64
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp
Normal file
64
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
#include <string.h>
#include "graph/types.h"
#include "aclnn_dispatch_layout.h"

// Backend that services HCCL communication tasks issued by this op.
enum NnopbaseHcclServerType {
    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
    NNOPBASE_HCCL_SERVER_TYPE_MTE,
    NNOPBASE_HCCL_SERVER_TYPE_END
};
// Declared weak: older runtimes may not export this symbol, so callers must
// null-check the function pointer before invoking it (see aclnnDispatchLayout).
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);

#ifdef __cplusplus
extern "C" {
#endif

// Generated inner entry points, implemented by the compiled op package.
extern aclnnStatus aclnnInnerDispatchLayoutGetWorkspaceSize(
    const aclTensor *topkIdx,
    int64_t numTokens,
    int64_t numRanks,
    int64_t numExperts,
    int64_t numTopk,
    const aclTensor *numTokensPerRank,
    const aclTensor *numTokensPerExpert,
    const aclTensor *isTokenInRank,
    uint64_t *workspaceSize,
    aclOpExecutor **executor);

extern aclnnStatus aclnnInnerDispatchLayout(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream);

// Public first-phase entry: computes the workspace size and builds the executor
// for DispatchLayout. Forwards all arguments unchanged to the inner implementation.
aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
    const aclTensor *topkIdx,
    int64_t numTokens,
    int64_t numRanks,
    int64_t numExperts,
    int64_t numTopk,
    const aclTensor *numTokensPerRank,
    const aclTensor *numTokensPerExpert,
    const aclTensor *isTokenInRank,
    uint64_t *workspaceSize,
    aclOpExecutor **executor)
{
    return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, numTokensPerRank,
                                                    numTokensPerExpert, isTokenInRank, workspaceSize, executor);
}

// Public second-phase entry: launches the DispatchLayout kernel on `stream`.
// If the runtime exports the hook, HCCL tasks are routed to the MTE server first.
aclnnStatus aclnnDispatchLayout(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    if (NnopbaseSetHcclServerType) {  // weak symbol may be unresolved at link time
        NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerDispatchLayout(workspace, workspaceSize, executor, stream);
}

#ifdef __cplusplus
}
#endif
|
||||
50
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h
Normal file
50
csrc/dispatch_layout/op_host/aclnn_dispatch_layout.h
Normal file
@@ -0,0 +1,50 @@
|
||||
#ifndef ACLNN_DISPATCH_LAYOUT_H_
#define ACLNN_DISPATCH_LAYOUT_H_

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

/* function: aclnnDispatchLayoutGetWorkspaceSize
 * First phase of the two-phase aclnn call: computes the device workspace size
 * and prepares the executor consumed by aclnnDispatchLayout.
 * topkIdx : required. Per-token selected expert indices (int64, 2-D).
 * numTokens : required
 * numRanks : required
 * numExperts : required
 * numTopk : required
 * numTokensPerRank : required (output tensor, int32)
 * numTokensPerExpert : required (output tensor, int32)
 * isTokenInRank : required (output tensor, int32)
 * workspaceSize : size of workspace(output).
 * executor : executor context(output).
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
    const aclTensor *topkIdx,
    int64_t numTokens,
    int64_t numRanks,
    int64_t numExperts,
    int64_t numTopk,
    const aclTensor *numTokensPerRank,
    const aclTensor *numTokensPerExpert,
    const aclTensor *isTokenInRank,
    uint64_t *workspaceSize,
    aclOpExecutor **executor);

/* function: aclnnDispatchLayout
 * Second phase: launches the DispatchLayout kernel with the executor prepared
 * by aclnnDispatchLayoutGetWorkspaceSize.
 * workspace : workspace memory addr(input).
 * workspaceSize : size of workspace(input).
 * executor : executor context(input).
 * stream : acl stream.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayout(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
|
||||
51
csrc/dispatch_layout/op_host/dispatch_layout.cpp
Normal file
51
csrc/dispatch_layout/op_host/dispatch_layout.cpp
Normal file
@@ -0,0 +1,51 @@
|
||||
#include "register/op_def_registry.h"

namespace ops {
// Op prototype registration for DispatchLayout: declares the input, attributes,
// outputs and the AI Core configuration used by the op compiler.
class DispatchLayout : public OpDef {
public:
    explicit DispatchLayout(const char *name) : OpDef(name)
    {
        // Single input: per-token selected expert indices, int64, ND layout.
        this->Input("topkIdx")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT64})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        // Scalar attributes; mirrored into tiling data on the host side.
        this->Attr("num_tokens").Int();
        this->Attr("num_ranks").Int();
        this->Attr("num_experts").Int();
        this->Attr("num_topk").Int();

        // Three int32 outputs describing the dispatch memory layout.
        this->Output("numTokensPerRank")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("numTokensPerExpert")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("isTokenInRank")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        // Dynamic-shape friendly AI Core configuration; aclnn invocation enabled.
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
            .ExtendCfgInfo("jitCompile.flag", "static_true")
            .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");

        // Only the ascend910_93 target is registered for this op.
        this->AICore().AddConfig("ascend910_93", aicore_config);
    }
};

OP_ADD(DispatchLayout);
} // namespace ops
|
||||
211
csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp
Normal file
211
csrc/dispatch_layout/op_host/dispatch_layout_tiling.cpp
Normal file
@@ -0,0 +1,211 @@
|
||||
#include <queue>
#include <vector>
#include <dlfcn.h>
#include <fcntl.h>
#include <cstdio>
#include <cstdlib>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cmath>
#include <cstdint>
#include <string>

#include "log/ops_log.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "../op_kernel/dispatch_layout_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/hccl/hccl_tiling.h"
#include "experiment/platform/platform/platform_infos_def.h"

using namespace ge;
namespace {
// Input / output / attribute slot indexes of the DispatchLayout op
// (must match the order declared in the op prototype registration).
constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0;

constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0;
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1;
constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2;

constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0;
constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1;
constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2;
constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3;
// Upper bounds enforced by the attribute validation below.
const int64_t MAX_COMM_WORLD_SIZE = 384;
const int64_t MAX_MOE_EXPERTS_NUM = 384;
// Workspace budget: system reserve + kernel scratch + A2 argument area (bytes).
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;

constexpr uint32_t TWO_DIMS = 2;  // expected rank of topkIdx
constexpr uint32_t K_MAX = 16;    // maximum accepted num_topk
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
// Dumps the tiling scalars at debug log level for troubleshooting.
static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &tilingData)
{
    OPS_LOG_D(nodeName, "numToken is %u.", tilingData.dispatchLayoutInfo.numTokens);
    OPS_LOG_D(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks);
    OPS_LOG_D(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts);
    OPS_LOG_D(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk);
    OPS_LOG_D(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize);
}
|
||||
|
||||
// Reads the four op attributes (num_tokens, num_ranks, num_experts, num_topk),
// validates their ranges and stores them into the tiling data.
// Returns GRAPH_FAILED when an attribute is missing or out of range.
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
                                               DispatchLayoutTilingData &tilingData)
{
    auto attrs = context->GetAttrs();
    OPS_CHECK(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto numTokensPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOKENS_INDEX));
    auto numRanksPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_RANKS_INDEX));
    // static_cast added for consistency with the other attribute lookups.
    auto numExpertsPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_EXPERTS_INDEX));
    auto numTopkPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOPK_INDEX));

    OPS_CHECK(numTokensPtr == nullptr, OPS_LOG_E(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED);
    OPS_CHECK(numRanksPtr == nullptr, OPS_LOG_E(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED);
    OPS_CHECK(numExpertsPtr == nullptr, OPS_LOG_E(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED);
    OPS_CHECK(numTopkPtr == nullptr, OPS_LOG_E(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED);

    OPS_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
        OPS_LOG_E(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", MAX_COMM_WORLD_SIZE, *numRanksPtr),
        return ge::GRAPH_FAILED);
    OPS_CHECK((*numExpertsPtr <= 0) || (*numExpertsPtr > MAX_MOE_EXPERTS_NUM),
        OPS_LOG_E(nodeName, "numExperts is invalid, only support (0, %ld], but got numExperts=%ld.", MAX_MOE_EXPERTS_NUM, *numExpertsPtr),
        return ge::GRAPH_FAILED);
    OPS_CHECK((*numTopkPtr <= 0) || (*numTopkPtr > K_MAX),
        OPS_LOG_E(nodeName, "numTopk is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr),
        return ge::GRAPH_FAILED);

    // NOTE(review): numTokens is stored without a range check — confirm callers
    // guarantee numTokens > 0 (the kernel divides the token range by core count).
    tilingData.dispatchLayoutInfo.numTokens = static_cast<uint32_t>(*numTokensPtr);
    tilingData.dispatchLayoutInfo.numRanks = static_cast<uint32_t>(*numRanksPtr);
    tilingData.dispatchLayoutInfo.numExperts = static_cast<uint32_t>(*numExpertsPtr);
    tilingData.dispatchLayoutInfo.numTopk = static_cast<uint32_t>(*numTopkPtr);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Reports the fixed device workspace requirement for this op:
// system reserve + kernel scratch + A2 argument area.
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
{
    size_t *workSpaces = context->GetWorkspaceSizes(1);
    OPS_CHECK(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
    workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Validates the dtypes of the input and the three outputs:
// topkIdx must be int64; the outputs must be int32.
// Returns false (after logging) on any mismatch or missing descriptor.
static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName)
{
    auto topkIdx = context->GetInputDesc(INPUT_TOPK_IDX_INDEX);
    auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX);
    auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
    auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX);

    OPS_CHECK(topkIdx == nullptr, OPS_LOG_E(nodeName, "topkIdx is null."), return false);
    OPS_CHECK(numTokensPerRank == nullptr, OPS_LOG_E(nodeName, "numTokensPerRank is null."), return false);
    OPS_CHECK(numTokensPerExpert == nullptr, OPS_LOG_E(nodeName, "numTokensPerExpert is null."), return false);
    OPS_CHECK(isTokenInRank == nullptr, OPS_LOG_E(nodeName, "isTokenInRank is null."), return false);

    // Error messages previously said "should be int" for every tensor; fixed to
    // name the actual required dtype.
    OPS_CHECK((topkIdx->GetDataType() != ge::DT_INT64),
        OPS_LOG_E(nodeName, "topkIdx datatype is invalid, datatype should be int64, but is %d.",
            static_cast<ge::DataType>(topkIdx->GetDataType())), return false);
    OPS_CHECK((numTokensPerRank->GetDataType() != ge::DT_INT32),
        OPS_LOG_E(nodeName, "numTokensPerRank datatype is invalid, datatype should be int32, but is %d.",
            static_cast<ge::DataType>(numTokensPerRank->GetDataType())), return false);
    OPS_CHECK((numTokensPerExpert->GetDataType() != ge::DT_INT32),
        OPS_LOG_E(nodeName, "numTokensPerExpert datatype is invalid, datatype should be int32, but is %d.",
            static_cast<ge::DataType>(numTokensPerExpert->GetDataType())), return false);
    OPS_CHECK((isTokenInRank->GetDataType() != ge::DT_INT32),
        OPS_LOG_E(nodeName, "isTokenInRank datatype is invalid, datatype should be int32, but is %d.",
            static_cast<ge::DataType>(isTokenInRank->GetDataType())), return false);

    return true;
}
|
||||
|
||||
// Validates that topkIdx is a 2-D tensor.
// Fixes: the original dereferenced the shape pointer without a null check and
// read GetDim(0)/GetDim(1) into unused locals before validating the rank.
static bool CheckTensorShape(gert::TilingContext *context, const char *nodeName)
{
    const gert::StorageShape *topkIdxStorageShape = context->GetInputShape(INPUT_TOPK_IDX_INDEX);
    OPS_CHECK(topkIdxStorageShape == nullptr, OPS_LOG_E(nodeName, "topkIdx shape is null."), return false);

    // Check the rank before touching any dimension.
    OPS_CHECK((topkIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS),
        OPS_LOG_E(nodeName, "topkIdx must be 2-dimension, but get %lu dim.",
            topkIdxStorageShape->GetStorageShape().GetDimNum()), return false);

    return true;
}
|
||||
|
||||
// Runs the dtype and shape validations; fails tiling on the first violation.
static ge::graphStatus TilingCheckTensor(
    gert::TilingContext *context, const char *nodeName)
{
    OPS_CHECK(!CheckTensorDataType(context, nodeName),
        OPS_LOG_E(nodeName, "params dataType is invalid."),
        return ge::GRAPH_FAILED);

    // Fixed copy-pasted message: this branch reports a shape failure, not dtype.
    OPS_CHECK(!CheckTensorShape(context, nodeName),
        OPS_LOG_E(nodeName, "params shape is invalid."),
        return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Main tiling routine: validates attributes and tensors, sets the workspace,
// then fills block dim and UB size from the platform description.
static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context)
{
    const char *nodeName = context->GetNodeName();
    DispatchLayoutTilingData *tilingData = context->GetTilingData<DispatchLayoutTilingData>();
    OPS_CHECK(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
    // Fixed copy-pasted op name in the log (said "NotifyDispatch").
    OPS_LOG_I(nodeName, "Enter DispatchLayout tiling check func.");

    OPS_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
        OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
        return ge::GRAPH_FAILED);

    OPS_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS,
        OPS_LOG_E(nodeName, "Tiling check param failed."),
        return ge::GRAPH_FAILED);

    OPS_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
        OPS_LOG_E(nodeName, "Tiling set workspace failed."),
        return ge::GRAPH_FAILED);

    // The original dereferenced the platform info without a null check and
    // fetched a socVersion string it never used; both removed.
    fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
    OPS_CHECK(platformInfoPtr == nullptr, OPS_LOG_E(nodeName, "platformInfoPtr is nullptr."), return ge::GRAPH_FAILED);

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint64_t ubSize = 0UL;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);

    // One block per vector core; the kernel splits tokens across blocks.
    uint32_t blockDim = aivNum;
    context->SetBlockDim(blockDim);
    tilingData->dispatchLayoutInfo.totalUbSize = ubSize;
    OPS_LOG_D(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
    PrintTilingDataInfo(nodeName, *tilingData);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Entry point registered with IMPL_OP_OPTILING; all work is in the impl.
// The original also dereferenced GetPlatformInfo() without a null check into
// unused locals (platformInfo, socVersion) — dead code removed.
static ge::graphStatus DispatchLayoutTilingFunc(gert::TilingContext *context)
{
    return DispatchLayoutTilingFuncImpl(context);
}
|
||||
|
||||
// No compile-time information is required for this op.
struct DispatchLayoutCompileInfo {};
// TilingParse hook: intentionally a no-op (nothing to parse).
ge::graphStatus TilingParseForDispatchLayout(gert::TilingParseContext *context)
{
    (void)context;
    return ge::GRAPH_SUCCESS;
}

// Registers the tiling and tiling-parse functions for DispatchLayout.
IMPL_OP_OPTILING(DispatchLayout)
    .Tiling(DispatchLayoutTilingFunc)
    .TilingParse<DispatchLayoutCompileInfo>(TilingParseForDispatchLayout);
|
||||
} // namespace optiling
|
||||
17
csrc/dispatch_layout/op_kernel/dispatch_layout.cpp
Normal file
17
csrc/dispatch_layout/op_kernel/dispatch_layout.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "kernel_operator.h"
#include "dispatch_layout.h"
#include "dispatch_layout_tiling.h"


// Device entry point for the DispatchLayout op: unpacks tiling data and runs
// the templated kernel. All pointers are global-memory addresses supplied by
// the aclnn launcher.
extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert,
    GM_ADDR isTokenInRank, GM_ADDR workspace, GM_ADDR tiling)
{
    REGISTER_TILING_DEFAULT(DispatchLayoutTilingData);
    GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling);

    TPipe pipe;

    // int32_t matches the DT_INT32 dtype of the three outputs in the host op def.
    DispatchLayout<int32_t> op;
    op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, workspace, &pipe, &tilingData);
    op.Process();
}
|
||||
153
csrc/dispatch_layout/op_kernel/dispatch_layout.h
Normal file
153
csrc/dispatch_layout/op_kernel/dispatch_layout.h
Normal file
@@ -0,0 +1,153 @@
|
||||
#ifndef DISPATCH_LAYOUT_H
#define DISPATCH_LAYOUT_H

#include <climits>
#include "kernel_operator.h"

#include "../common/comm_args.h"
#include "../common/data_copy.h"
#include "../common/sync_collectives.h"
#include "../common/moe_distribute_base.h"
#include "dispatch_layout_tiling.h"

using namespace AscendC;
using namespace Moe;

// UB copies must be 32-byte aligned; all buffer lengths are rounded up to this.
constexpr uint32_t UB_32_ALIGN = 32U;
// NOTE(review): hard-coded vector-core count used for token partitioning, while
// host tiling sets blockDim from GetCoreNumAiv() — confirm these agree on every
// supported target, otherwise part of the input is never processed.
constexpr uint32_t AIV_NUM = 48;
||||
|
||||
template <AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
// Device-side implementation of DispatchLayout.
// Each vector core processes a contiguous slice of tokens from topkIdx and
// produces, per token, which ranks it is routed to (isTokenInRank) plus
// per-rank / per-expert token counts (accumulated across cores via atomic add).
// T is the output element type (int32_t in the current launcher).
template <typename T>
class DispatchLayout {

public:
    __aicore__ inline DispatchLayout() {};

    // Caches tiling scalars, splits the token range across AIV_NUM cores
    // (first `numTokens % AIV_NUM` cores take one extra token), and binds the
    // global-memory buffers at this core's byte offsets.
    __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank,
                                GM_ADDR workspace, TPipe *pipe, const DispatchLayoutTilingData *tilingData)
    {
        numTokens_ = tilingData->dispatchLayoutInfo.numTokens;
        numRanks_ = tilingData->dispatchLayoutInfo.numRanks;
        numExperts_ = tilingData->dispatchLayoutInfo.numExperts;
        numTopk_ = tilingData->dispatchLayoutInfo.numTopk;
        tpipe_ = pipe;

        coreIdx_ = GetBlockIdx();
        uint32_t temp = numTokens_ / AIV_NUM;
        uint32_t restNum = numTokens_ % AIV_NUM;
        int64_t topkIdxOffset;
        int64_t isTokenOffset;
        tempTokens_ = temp;
        if (coreIdx_ < restNum) {
            tempTokens_++;
        }
        // UB buffer lengths, rounded up to the 32-byte copy granularity.
        topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN;
        numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;
        numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;
        isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN;

        // NOTE(review): these GM offsets are the 32B-aligned slice lengths, i.e.
        // they assume the per-core slices are laid out padded to 32 bytes in
        // global memory. If the tensors are densely packed this overshoots
        // whenever a slice's byte length is not a multiple of 32 — confirm the
        // producer's layout.
        if (coreIdx_ < restNum) {
            topkIdxOffset = coreIdx_ * topkIdx32AlignIntLen_;
            isTokenOffset = coreIdx_ * isTokenInRank32AlignIntLen_;
        } else {
            // Cores before restNum handled (tempTokens_ + 1) tokens each.
            topkIdxOffset = restNum * Ceil((tempTokens_ + 1) * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN
                + (coreIdx_ - restNum) * topkIdx32AlignIntLen_;
            isTokenOffset = restNum * Ceil((tempTokens_ + 1) * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN
                + (coreIdx_ - restNum) * isTokenInRank32AlignIntLen_;
        }

        topkIdxGM_.SetGlobalBuffer((__gm__ int64_t*)(topkIdx + topkIdxOffset));
        numTokensPerRankGM_.SetGlobalBuffer((__gm__ T*)numTokensPerRank);
        numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T*)numTokensPerExpert);
        isTokenInRankGM_.SetGlobalBuffer((__gm__ T*)(isTokenInRank + isTokenOffset));


    }

    // Loads this core's topkIdx slice into UB, counts tokens per rank/expert
    // with scalar loops, then writes the results back: isTokenInRank goes to
    // this core's private slice, while the two count vectors are accumulated
    // across cores under atomic add.
    __aicore__ inline void Process()
    {
        tpipe_->Reset();
        tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_);
        tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_);
        tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_);
        tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_);
        tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T));

        // Bring this core's topkIdx slice into UB before scalar reads.
        LocalTensor<int64_t> topkIdxTensor = topkIdxBuf_.AllocTensor<int64_t>();
        const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U};
        const DataCopyPadExtParams<int64_t> padParams{false, 0U, 0U, 0U};
        DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams);
        SyncFunc<AscendC::HardEvent::MTE2_S>();

        LocalTensor<T> numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor<T>();
        LocalTensor<T> numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor<T>();
        LocalTensor<T> isTokenInRankTensor = isTokenInRankBuf_.AllocTensor<T>();
        LocalTensor<T> seenRankTensor = seenRankBuf_.AllocTensor<T>();
        Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
        Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
        Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
        SyncFunc<AscendC::HardEvent::V_S>();

        // Experts are assumed evenly divided across ranks.
        int experts_per_rank = numExperts_ / numRanks_;
        for (int i = 0; i < tempTokens_; ++i) {
            // seenRank marks which ranks this token has already been counted
            // for, so a token routed to several experts on the same rank is
            // counted once per rank.
            SyncFunc<AscendC::HardEvent::S_V>();
            Duplicate<T>(seenRankTensor, 0, numRanks_);
            SyncFunc<AscendC::HardEvent::V_S>();
            for (int j = 0; j < numTopk_; ++j) {
                int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j);
                uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1;
                numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num);
                int rank_id = expert_idx / experts_per_rank;
                if (!seenRankTensor.GetValue(rank_id)) {
                    uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1;
                    isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1);
                    seenRankTensor.SetValue(rank_id, 1);
                    numTokensPerRankTensor.SetValue(rank_id, per_rank_num);
                }
            }
        }

        // NOTE(review): the copy-outs below follow scalar SetValue writes with
        // no explicit S->MTE3 sync in this block — confirm ordering is
        // guaranteed by the framework on the target pipeline.
        const DataCopyExtParams isTokenInRankDataCopyParams{1U, isTokenInRank32AlignIntLen_, 0U, 0U, 0U};
        DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams);
        // Counts are partial per core: accumulate into GM under atomic add.
        AscendC::SetAtomicAdd<T>();
        const DataCopyExtParams numTokensPerRankDataCopyParams{1U, numTokensPerRank32AlignIntLen_, 0U, 0U, 0U};
        DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams);
        const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U};
        DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams);
        AscendC::SetAtomicNone();
    }

private:
    // Global-memory views (topkIdx/isTokenInRank are offset to this core's slice).
    GlobalTensor<int64_t> topkIdxGM_;
    GlobalTensor<T> numTokensPerRankGM_;
    GlobalTensor<T> numTokensPerExpertGM_;
    GlobalTensor<T> isTokenInRankGM_;

    // UB staging buffers, sized in Init().
    TBuf<> topkIdxBuf_;
    TBuf<> numTokensPerRankBuf_;
    TBuf<> numTokensPerExpertBuf_;
    TBuf<> isTokenInRankBuf_;
    TBuf<> seenRankBuf_;

    TPipe *tpipe_{nullptr};
    // Tiling scalars (see DispatchLayoutInfo).
    uint32_t numTokens_{0};
    uint32_t numRanks_{0};
    uint32_t numExperts_{0};
    uint32_t numTopk_{0};
    uint32_t coreIdx_{0};      // this core's block index
    uint32_t tempTokens_{0};   // number of tokens handled by this core

    // 32-byte-aligned buffer lengths in bytes.
    uint32_t topkIdx32AlignIntLen_{0};
    uint32_t numTokensPerRank32AlignIntLen_{0};
    uint32_t numTokensPerExpert32AlignIntLen_{0};
    uint32_t isTokenInRank32AlignIntLen_{0};
};
|
||||
|
||||
#endif // DISPATCH_LAYOUT_H
|
||||
20
csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h
Normal file
20
csrc/dispatch_layout/op_kernel/dispatch_layout_tiling.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef DISPATCH_LAYOUT_TILING_H
#define DISPATCH_LAYOUT_TILING_H

#include "kernel_tiling/kernel_tiling.h"

// Scalar tiling parameters shared between host tiling and the device kernel.
struct DispatchLayoutInfo {
    uint32_t numTokens;    // number of input tokens
    uint32_t numRanks;     // communication world size
    uint32_t numExperts;   // total number of MoE experts
    uint32_t numTopk;      // experts selected per token
    uint64_t totalUbSize;  // UB memory size of the target core, in bytes
};

// Layout must match what the host writes via GetTilingData / the kernel reads
// via GET_TILING_DATA_WITH_STRUCT.
struct DispatchLayoutTilingData {
    Mc2InitTiling mc2InitTiling;
    Mc2CcTiling mc2CcTiling1;
    DispatchLayoutInfo dispatchLayoutInfo;
};

#endif
|
||||
Reference in New Issue
Block a user