[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)
What this PR does / why we need it? Optimization of the fused operator for Qwen3 32B: Matmul, AllReduce, Add, and RMSNorm Does this PR introduce _any_ user-facing change? No How was this patch tested? vLLM version: v0.11.2 vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: tongrunze <t00574058@china.huawei.com> Co-authored-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
@@ -11,7 +11,20 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
exit 0
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
# ASCEND910B (A2) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
# depdendency: catlass
|
||||
git config --global --add safe.directory "$ROOT_DIR"
|
||||
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass/include
|
||||
if [[ ! -d "${CATLASS_PATH}" ]]; then
|
||||
echo "depdendency catlass is missing, try to fetch it..."
|
||||
if ! git submodule update --init --recursive; then
|
||||
echo "fetch failed"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
|
||||
51
csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt
Normal file
51
csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME MatmulAllreduceAddRmsnormTensorList
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
matmul_allreduce_add_rmsnorm_def.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_matmul_allreduce_add_rmsnorm.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_matmul_allreduce_add_rmsnorm.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_matmul_allreduce_add_rmsnorm.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
matmul_allreduce_add_rmsnorm_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
matmul_allreduce_add_rmsnorm_proto.cpp
|
||||
)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_matmul_allreduce_add_rmsnorm.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn/opdev/platform.h"
|
||||
#include "aclnn_matmul_allreduce_add_rmsnorm.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnormGetWorkspaceSize(
|
||||
const aclTensor *x1,
|
||||
const aclTensor *x2,
|
||||
const aclTensor *residual,
|
||||
const aclTensor *gamma,
|
||||
char *groupTp,
|
||||
int64_t tpRankSize,
|
||||
int64_t tpRankId,
|
||||
double epsilon,
|
||||
bool isTransB,
|
||||
bool isGatherAddOut,
|
||||
const aclTensor *yOut,
|
||||
const aclTensor *addOutOut,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnorm(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
aclnnStatus aclnnMatmulAllreduceAddRmsnormGetWorkspaceSize(
|
||||
const aclTensor *x1,
|
||||
const aclTensor *x2,
|
||||
const aclTensor *residual,
|
||||
const aclTensor *gamma,
|
||||
char *groupTp,
|
||||
int64_t tpRankSize,
|
||||
int64_t tpRankId,
|
||||
double epsilon,
|
||||
bool isTransB,
|
||||
bool isGatherAddOut,
|
||||
const aclTensor *y,
|
||||
const aclTensor *addOut,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerMatmulAllreduceAddRmsnormGetWorkspaceSize(x1, x2, residual,
|
||||
gamma, groupTp, tpRankSize, tpRankId, epsilon, isTransB, isGatherAddOut, y, addOut, workspaceSize, executor);
|
||||
}
|
||||
|
||||
aclnnStatus aclnnMatmulAllreduceAddRmsnorm(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
return aclnnInnerMatmulAllreduceAddRmsnorm(workspace, workspaceSize, executor, stream);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACLNN_MATMUL_ALLREDUCE_ADD_RMSNORM
|
||||
#define ACLNN_MATMUL_ALLREDUCE_ADD_RMSNORM
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMatmulAllreduceAddRmsnormGetWorkspaceSize(
|
||||
const aclTensor *x1,
|
||||
const aclTensor *x2,
|
||||
const aclTensor *residual,
|
||||
const aclTensor *gamma,
|
||||
char *groupTp,
|
||||
int64_t tpRankSize,
|
||||
int64_t tpRankId,
|
||||
double epsilon,
|
||||
bool isTransB,
|
||||
bool isGatherAddOut,
|
||||
const aclTensor *y,
|
||||
const aclTensor *addOut,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnMatmulAllreduceAddRmsnorm(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops{
|
||||
class MatmulAllreduceAddRmsnorm : public OpDef {
|
||||
public:
|
||||
explicit MatmulAllreduceAddRmsnorm(const char* name) : OpDef(name)
|
||||
{
|
||||
this->Input("x1")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("x2")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("residual")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("gamma")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("y")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("add_out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->Attr("group_tp").String();
|
||||
this->Attr("tp_rank_size").Int();
|
||||
this->Attr("tp_rank_id").Int();
|
||||
this->Attr("epsilon").AttrType(OPTIONAL).Float(1e-6);
|
||||
this->Attr("is_trans_b").AttrType(OPTIONAL).Bool(false);
|
||||
this->Attr("is_gather_add_out").AttrType(OPTIONAL).Bool(false);
|
||||
|
||||
this->MC2().HcclGroup({"group_tp"});
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(MatmulAllreduceAddRmsnorm);
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ge {
|
||||
constexpr uint32_t RESIDUAL_INDEX = 3;
|
||||
constexpr uint32_t OUTPUT_Y_INDEX = 0;
|
||||
constexpr uint32_t OUTPUT_ADD_OUT_INDEX = 1;
|
||||
constexpr int SHAPE_INDEX0 = 0;
|
||||
constexpr int SHAPE_INDEX1 = 1;
|
||||
constexpr int SHAPE_INDEX2 = 2;
|
||||
constexpr int DIM_NUM_2 = 2;
|
||||
constexpr int DIM_NUM_3 = 3;
|
||||
|
||||
static void CloneShape(const gert::Shape* src, gert::Shape* dst)
|
||||
{
|
||||
int ndim = src->GetDimNum();
|
||||
dst->SetDimNum(ndim);
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
dst->SetDim(i, src->GetDim(i));
|
||||
}
|
||||
}
|
||||
|
||||
static ge::graphStatus InferShape(gert::InferShapeContext* context)
|
||||
{
|
||||
const gert::Shape* residualShape = context->GetInputShape(RESIDUAL_INDEX);
|
||||
int residualDimNum = residualShape->GetDimNum();
|
||||
|
||||
if (residualDimNum != DIM_NUM_2 && residualDimNum != DIM_NUM_3) {
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
gert::Shape* x1OutShape = context->GetOutputShape(OUTPUT_Y_INDEX);
|
||||
gert::Shape* addOutShape = context->GetOutputShape(OUTPUT_ADD_OUT_INDEX);
|
||||
CloneShape(residualShape, x1OutShape);
|
||||
CloneShape(residualShape, addOutShape);
|
||||
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
|
||||
{
|
||||
const auto residualDataType = context->GetInputDataType(RESIDUAL_INDEX);
|
||||
context->SetOutputDataType(OUTPUT_Y_INDEX, residualDataType);
|
||||
context->SetOutputDataType(OUTPUT_ADD_OUT_INDEX, residualDataType);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP(MatmulAllreduceAddRmsnorm)
|
||||
.InferShape(InferShape)
|
||||
.InferDataType(InferDataType);
|
||||
}
|
||||
@@ -0,0 +1,619 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "../op_kernel/matmul_allreduce_add_rmsnorm_tiling.h"
|
||||
#include "matmul_allreduce_add_rmsnorm_workspace.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
|
||||
typedef enum {
|
||||
ATTR_TP_INDEX = 0,
|
||||
ATTR_RANK_SIZE_INDEX,
|
||||
ATTR_RANK_ID_INDEX,
|
||||
ATTR_EPSILON_INDEX,
|
||||
ATTR_IS_TRANS_B_INDEX,
|
||||
ATTR_IS_GATHER_ADD_OUT_INDEX
|
||||
} ATTR_TYPE;
|
||||
|
||||
int32_t CeilDev(int32_t num, int32_t div)
|
||||
{
|
||||
if (div == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (num + div - 1) / div;
|
||||
}
|
||||
static constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
|
||||
static constexpr uint32_t BATCH_SIZE_ONE = 1;
|
||||
static constexpr uint32_t DEFAULT_ROW = 128;
|
||||
static constexpr uint32_t DEFAULT_COL = 256;
|
||||
static constexpr uint32_t DEFAULT_SWIZZLE_COUNT = 4;
|
||||
static constexpr int32_t VALID_UB_MOVE_NUM = 20480;
|
||||
static constexpr int32_t COMMDATASPLIT_ONE = 1;
|
||||
static constexpr int32_t COMM_DATA_DIRECT = 0;
|
||||
static constexpr uint32_t ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT = 128;
|
||||
static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_DATASPLIT_DEFAULT = 16;
|
||||
static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 100;
|
||||
static constexpr int32_t HALF_KBYTE = 512;
|
||||
static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT = 14;
|
||||
static constexpr int32_t SWIZZLE_DIRECT_ONE = 1;
|
||||
static constexpr int32_t COMMNPUSPLIT_ONE = 1;
|
||||
static constexpr int32_t COMMDATASPLIT_SIXTEEN = 16;
|
||||
constexpr int32_t SECOND_TO_MS = 1000;
|
||||
constexpr int64_t MATMUL_BASE_100US = static_cast<int64_t>(1024) * 8192 * 1024;
|
||||
constexpr int64_t ALLREDUCE_BASE_100US = 4096 * 1024;
|
||||
constexpr double ONE_K = 1024.0;
|
||||
constexpr double B1_FLOP_PER_MS = (364 * 0.8) * 1e9;
|
||||
constexpr double DOUBLE = 2.0;
|
||||
constexpr double HALF_PROB = 0.5;
|
||||
constexpr int32_t CONDITION_M_ST = 0;
|
||||
constexpr int32_t CONDITION_M_END = 1;
|
||||
constexpr int32_t CONDITION_K_ST = 2;
|
||||
constexpr int32_t CONDITION_K_END = 3;
|
||||
constexpr int32_t CONDITION_N_ST = 4;
|
||||
constexpr int32_t CONDITION_N_END = 5;
|
||||
constexpr int32_t RANKSIZE_FOUR = 4;
|
||||
constexpr int32_t RANKSIZE_EIGHT = 8;
|
||||
constexpr int32_t DIV_TWO = 2;
|
||||
constexpr int32_t LENPERLOOP_DEFAULT = 5120;
|
||||
constexpr int32_t MIN_UB_MOVE_NUM = 5120;
|
||||
constexpr int32_t MAX_UB_NUM = 97280; // 190 * 1024 / 2
|
||||
constexpr int32_t MAX_P_VALUE = 15;
|
||||
|
||||
constexpr int32_t DIM_NUM_TWO = 2;
|
||||
constexpr int32_t DIM_NUM_THREE = 3;
|
||||
constexpr int32_t DIM_INDEX_ZERO = 0;
|
||||
constexpr int32_t DIM_INDEX_ONE = 1;
|
||||
constexpr int32_t DIM_INDEX_TWO = 2;
|
||||
|
||||
static constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
|
||||
static constexpr uint32_t USE_CORE_NUM = 20;
|
||||
|
||||
static std::vector<double> ALLREDUCE_UBMOVENUM_COEF = {{-1.72352427e+01,
|
||||
2.56887672e-03,
|
||||
-8.21819480e+00,
|
||||
8.70965589e+01,
|
||||
-3.63853858e-01,
|
||||
1.27789264e+01,
|
||||
1.29782183e+02,
|
||||
1.90250023e-02,
|
||||
-3.48175441e+00,
|
||||
6.18921914e+03,
|
||||
3.77072171e+03,
|
||||
-5.86895290e+01,
|
||||
-8.70740991e-01,
|
||||
-1.40262280e-04,
|
||||
-2.81910331e-08,
|
||||
3.22795486e-05,
|
||||
-4.84522320e-03,
|
||||
2.94839177e-01,
|
||||
2.97260958e-03,
|
||||
9.08844709e+01,
|
||||
-5.80426209e-10,
|
||||
38.183465184603484}};
|
||||
|
||||
static std::map<int, std::vector<std::vector<int>>> ALLREDUCE_EIGHT_RANK_FP16_M0_MAP = {
|
||||
{128,
|
||||
{{-1, 31220, -1, 2147483647, -1, 768},
|
||||
{31220, 36980, 1280, 2147483647, -1, 768},
|
||||
{36980, 2147483647, -1, 2147483647, -1, 768},
|
||||
{-1, 2147483647, -1, 2147483647, 768, 2147483647}}},
|
||||
{256, {{31220, 36980, -1, 1280, -1, 768}}}};
|
||||
|
||||
static std::map<int, std::vector<std::vector<int>>> ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_MAP = {
|
||||
{100,
|
||||
{{-1, 3072, -1, 2147483647, -1, 768},
|
||||
{3072, 19680, -1, 3072, -1, 768},
|
||||
{-1, 3072, -1, 2147483647, 768, 1536},
|
||||
{3072, 19680, -1, 3072, 768, 1536},
|
||||
{-1, 2147483647, 1792, 2976, 1536, 13312}}},
|
||||
{30,
|
||||
{{3072, 19680, 3072, 2147483647, -1, 768},
|
||||
{19680, 2147483647, -1, 3072, -1, 1536},
|
||||
{-1, 2147483647, -1, 1792, 1536, 13312},
|
||||
{-1, 768, 2976, 2147483647, 5376, 13312},
|
||||
{-1, 768, -1, 2147483647, 13312, 2147483647},
|
||||
{26880, 2147483647, -1, 3072, 13312, 2147483647}}},
|
||||
{20,
|
||||
{{3072, 19680, 3072, 2147483647, 768, 1536},
|
||||
{19680, 2147483647, 3072, 2147483647, -1, 1536},
|
||||
{-1, 2147483647, 2976, 2147483647, 1536, 5376},
|
||||
{768, 2147483647, 2976, 2147483647, 5376, 13312},
|
||||
{768, 26880, -1, 2147483647, 13312, 2147483647},
|
||||
{26880, 2147483647, 3072, 2147483647, 13312, 2147483647}}}};
|
||||
|
||||
static std::vector<double> ALLREDUCE_PVALUE_COEF = {{-4.23166350e+00,
|
||||
6.71137487e-04,
|
||||
-1.33434156e+00,
|
||||
1.12915884e+01,
|
||||
-7.85892737e-02,
|
||||
2.59059897e+00,
|
||||
3.22129881e+01,
|
||||
-5.15776887e-02,
|
||||
9.15542742e-01,
|
||||
1.56322201e+03,
|
||||
3.61977421e+01,
|
||||
-5.49544589e-01,
|
||||
-2.66903417e-01,
|
||||
-3.68521920e-05,
|
||||
-6.40666333e-09,
|
||||
6.77406054e-06,
|
||||
-9.92992099e-04,
|
||||
5.60658043e-02,
|
||||
2.69372863e-04,
|
||||
2.17222337e+01,
|
||||
-1.17749660e-10,
|
||||
6.100544547671263}};
|
||||
|
||||
double GetMTETime(double mknGB, int32_t m0, int32_t n0, double aBindWidth = 3.0, double bBindWidth = 3.0)
|
||||
{
|
||||
// 预估Matmul计算的MTE2搬运时间
|
||||
return DOUBLE * mknGB * (SECOND_TO_MS / ONE_K) * (1.0 / (n0 * aBindWidth) + 1.0 / (m0 * bBindWidth));
|
||||
}
|
||||
|
||||
int32_t AllReduceUbMoveNum(int m, int k, int n)
|
||||
{
|
||||
double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40;
|
||||
double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n;
|
||||
double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K);
|
||||
double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL);
|
||||
double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW);
|
||||
double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2);
|
||||
double matmulPredict = std::max(cubePredict, mteTimePredict);
|
||||
double c0 = matmulPredict / commPredict;
|
||||
double c1 = 1.0 * m * n / k;
|
||||
double c2 = sqrt(c1);
|
||||
double c3 = sqrt(1.0 * m * n) / k;
|
||||
double c4 = c3 * c3;
|
||||
double c5 = matmulPredict;
|
||||
double c6 = commPredict;
|
||||
double c7 = 1.0 * n / m;
|
||||
double c8 = 1.0 * m * n / sqrt(k);
|
||||
double c9 = 1.0 * m * n * sqrt(k);
|
||||
double c10 = sqrt(1.0 * m * n) * k;
|
||||
double c11 = sqrt(1.0 * m * n * k);
|
||||
double c12 = sqrt(1.0 * m * n);
|
||||
double c13 = 1.0 * k * k / sqrt(1.0 * m * n);
|
||||
double c14 = 1.0 * k * k * sqrt(1.0 * m * n);
|
||||
double ubMoveNumDouble = 0;
|
||||
std::vector<double> feats_update = {c0,
|
||||
c1,
|
||||
c2,
|
||||
c3,
|
||||
c4,
|
||||
c5,
|
||||
c6,
|
||||
c7,
|
||||
1.0 / c0,
|
||||
1.0 / c1,
|
||||
1.0 / c2,
|
||||
1.0 / c3,
|
||||
1.0 / c4,
|
||||
c8,
|
||||
c9,
|
||||
c10,
|
||||
c11,
|
||||
c12,
|
||||
c13,
|
||||
1.0 / c13,
|
||||
c14,
|
||||
1};
|
||||
for (uint32_t i = 0; i < feats_update.size(); i++) {
|
||||
ubMoveNumDouble += feats_update[i] * ALLREDUCE_UBMOVENUM_COEF[i];
|
||||
}
|
||||
|
||||
return std::min(std::max(static_cast<int32_t>(ubMoveNumDouble) * HALF_KBYTE, MIN_UB_MOVE_NUM), MAX_UB_NUM);
|
||||
}
|
||||
|
||||
int32_t AllReducePValue(int m, int k, int n)
|
||||
{
|
||||
double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40;
|
||||
double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n;
|
||||
double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K);
|
||||
double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL);
|
||||
double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW);
|
||||
double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2);
|
||||
double matmulPredict = std::max(cubePredict, mteTimePredict);
|
||||
double c0 = matmulPredict / commPredict;
|
||||
double c1 = 1.0 * m * n / k;
|
||||
double c2 = sqrt(c1);
|
||||
double c3 = sqrt(1.0 * m * n) / k;
|
||||
double c4 = c3 * c3;
|
||||
double c5 = matmulPredict;
|
||||
double c6 = commPredict;
|
||||
double c7 = 1.0 * n / m;
|
||||
double c8 = 1.0 * m * n / sqrt(k);
|
||||
double c9 = 1.0 * m * n * sqrt(k);
|
||||
double c10 = sqrt(1.0 * m * n) * k;
|
||||
double c11 = sqrt(1.0 * m * n * k);
|
||||
double c12 = sqrt(1.0 * m * n);
|
||||
double c13 = 1.0 * k * k / sqrt(1.0 * m * n);
|
||||
double c14 = 1.0 * k * k * sqrt(1.0 * m * n);
|
||||
double pValueDouble = 0;
|
||||
std::vector<double> feats_update = {c0,
|
||||
c1,
|
||||
c2,
|
||||
c3,
|
||||
c4,
|
||||
c5,
|
||||
c6,
|
||||
c7,
|
||||
1.0 / c0,
|
||||
1.0 / c1,
|
||||
1.0 / c2,
|
||||
1.0 / c3,
|
||||
1.0 / c4,
|
||||
c8,
|
||||
c9,
|
||||
c10,
|
||||
c11,
|
||||
c12,
|
||||
c13,
|
||||
1.0 / c13,
|
||||
c14,
|
||||
1};
|
||||
for (uint32_t i = 0; i < feats_update.size(); i++) {
|
||||
pValueDouble += feats_update[i] * ALLREDUCE_PVALUE_COEF[i];
|
||||
}
|
||||
|
||||
return std::min(std::max(static_cast<int32_t>(pValueDouble), 1), MAX_P_VALUE);
|
||||
}
|
||||
|
||||
int32_t GetValueFromMKNConditionMap(
|
||||
int32_t m, int32_t k, int32_t n, int32_t defaultValue, std::map<int, std::vector<std::vector<int>>> conditionMap)
|
||||
{
|
||||
int32_t value = defaultValue;
|
||||
for (auto &item : conditionMap) {
|
||||
for (auto &condition : item.second) {
|
||||
bool inRange = m > condition[CONDITION_M_ST] && m <= condition[CONDITION_M_END] &&
|
||||
k > condition[CONDITION_K_ST] && k <= condition[CONDITION_K_END] &&
|
||||
n > condition[CONDITION_N_ST] && n <= condition[CONDITION_N_END];
|
||||
if (inRange) {
|
||||
return item.first;
|
||||
}
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
void AllReduceEightRankFP16GetDefaultTiling(
|
||||
gert::TilingContext *context, PPTilingData &ppTilingData, CommTilingData &commTilingData)
|
||||
{
|
||||
int32_t m = ppTilingData.opShape.m;
|
||||
int32_t k = ppTilingData.opShape.k;
|
||||
int32_t n = ppTilingData.opShape.n;
|
||||
|
||||
ppTilingData.m0 =
|
||||
GetValueFromMKNConditionMap(m, k, n, ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT, ALLREDUCE_EIGHT_RANK_FP16_M0_MAP);
|
||||
|
||||
ppTilingData.k0 = DEFAULT_COL;
|
||||
ppTilingData.n0 = ppTilingData.m0 == DEFAULT_ROW ? DEFAULT_COL : DEFAULT_ROW;
|
||||
|
||||
ppTilingData.mLoop = CeilDev(m, ppTilingData.m0);
|
||||
ppTilingData.nLoop = CeilDev(n, ppTilingData.n0);
|
||||
ppTilingData.kLoop = CeilDev(k, ppTilingData.k0);
|
||||
|
||||
ppTilingData.coreLoop = ppTilingData.opShape.batchSize * ppTilingData.mLoop * ppTilingData.nLoop;
|
||||
ppTilingData.swizzlDirect = SWIZZLE_DIRECT_ONE;
|
||||
ppTilingData.swizzlCount = DEFAULT_SWIZZLE_COUNT;
|
||||
ppTilingData.tilingKey = 0;
|
||||
ppTilingData.splitK = 0;
|
||||
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
ppTilingData.blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
|
||||
|
||||
commTilingData.ubMoveNum =
|
||||
GetValueFromMKNConditionMap(
|
||||
m, k, n, ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_MAP) *
|
||||
HALF_KBYTE;
|
||||
commTilingData.pValue = ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT;
|
||||
|
||||
commTilingData.commDirect = COMM_DATA_DIRECT;
|
||||
commTilingData.commNpuSplit = COMMNPUSPLIT_ONE;
|
||||
commTilingData.commDataSplit = COMMDATASPLIT_SIXTEEN;
|
||||
commTilingData.is91093 = 0;
|
||||
commTilingData.withSerialMode = 0;
|
||||
commTilingData.tag = 0;
|
||||
commTilingData.write2OtherRank = 0;
|
||||
}
|
||||
|
||||
void GetDefaultTiling(gert::TilingContext *context, PPTilingData &ppTilingData, CommTilingData &commTilingData)
|
||||
{
|
||||
int32_t m = ppTilingData.opShape.m;
|
||||
int32_t k = ppTilingData.opShape.k;
|
||||
int32_t n = ppTilingData.opShape.n;
|
||||
|
||||
ppTilingData.m0 = DEFAULT_ROW;
|
||||
ppTilingData.n0 = DEFAULT_COL;
|
||||
ppTilingData.k0 = DEFAULT_COL;
|
||||
|
||||
ppTilingData.mLoop = CeilDev(m, ppTilingData.m0);
|
||||
ppTilingData.nLoop = CeilDev(n, ppTilingData.n0);
|
||||
ppTilingData.kLoop = CeilDev(k, ppTilingData.k0);
|
||||
ppTilingData.coreLoop = ppTilingData.opShape.batchSize * ppTilingData.mLoop * ppTilingData.nLoop;
|
||||
|
||||
ppTilingData.swizzlDirect = m > n ? 0 : 1;
|
||||
ppTilingData.swizzlCount = DEFAULT_SWIZZLE_COUNT;
|
||||
ppTilingData.tilingKey = 0;
|
||||
ppTilingData.splitK = 0;
|
||||
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
ppTilingData.blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum);
|
||||
|
||||
commTilingData.ubMoveNum = AllReduceUbMoveNum(m, k, n);
|
||||
commTilingData.pValue = AllReducePValue(m, k, n);
|
||||
commTilingData.commNpuSplit = commTilingData.rankSize;
|
||||
commTilingData.commDataSplit = COMMDATASPLIT_ONE;
|
||||
commTilingData.commDirect = COMM_DATA_DIRECT;
|
||||
commTilingData.lenPerLoop = ppTilingData.m0 * ppTilingData.n0 * commTilingData.pValue * ppTilingData.blockDim;
|
||||
commTilingData.lenPerLoop = commTilingData.lenPerLoop / commTilingData.rankSize;
|
||||
commTilingData.is91093 = 0;
|
||||
commTilingData.withSerialMode = 0;
|
||||
commTilingData.tag = 0;
|
||||
commTilingData.write2OtherRank = 0;
|
||||
}
|
||||
|
||||
static inline void GetRmsnormTilingData(RmsNormTilingData &rmsnormtiling, std::vector<int64_t> &shapeVec,
|
||||
std::vector<int64_t> &oriShapeVec, uint32_t calcBytes = 0, uint32_t loopCount = 1, float ep = 1e-5)
|
||||
{
|
||||
ge::Shape srcShape(shapeVec);
|
||||
ge::Shape oriSrcShape(oriShapeVec);
|
||||
uint32_t minValue = 0;
|
||||
uint32_t maxValue = 0;
|
||||
AscendC::GetRmsNormMaxMinTmpSize(srcShape, sizeof(uint16_t), maxValue, minValue, false);
|
||||
|
||||
if (calcBytes < minValue) {
|
||||
rmsnormtiling.calcBytes = minValue;
|
||||
} else if (calcBytes > maxValue) {
|
||||
rmsnormtiling.calcBytes = maxValue;
|
||||
} else {
|
||||
rmsnormtiling.calcBytes = calcBytes;
|
||||
}
|
||||
|
||||
optiling::RmsNormTiling tilingdata;
|
||||
AscendC::GetRmsNormTilingInfo(srcShape, oriSrcShape, rmsnormtiling.calcBytes, sizeof(uint16_t), tilingdata, false);
|
||||
size_t tilingSize = tilingdata.GetDataSize();
|
||||
tilingdata.SaveToBuffer(&rmsnormtiling.tiling, tilingSize);
|
||||
rmsnormtiling.epsilon = ep;
|
||||
rmsnormtiling.loopCount = loopCount;
|
||||
}
|
||||
|
||||
static inline void GetQuantTilingData(QuantInfo &quantInfo)
|
||||
{
|
||||
quantInfo.dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
quantInfo.dequantGroupSize = -1;
|
||||
quantInfo.quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
quantInfo.quantGroupSize = -1;
|
||||
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(
|
||||
gert::TilingContext *context, const char *nodeName, MatmulAllreduceAddRmsnormTilingData &tilingData)
|
||||
|
||||
{
|
||||
CommTilingData &commTilingData = tilingData.matmulAllreduceAddRmsnormInfo.commTilingData;
|
||||
PPTilingData &ppTilingData = tilingData.matmulAllreduceAddRmsnormInfo.ppTilingData;
|
||||
RmsNormTilingData &rmsnormTilingData = tilingData.matmulAllreduceAddRmsnormInfo.rmsnormTilingData;
|
||||
QuantInfo &quantInfo = tilingData.matmulAllreduceAddRmsnormInfo.quantInfo;
|
||||
|
||||
auto attrs = context->GetAttrs();
|
||||
OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
auto RankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_SIZE_INDEX);
|
||||
auto RankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_RANK_ID_INDEX);
|
||||
|
||||
bool isTransB = *(attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B_INDEX));
|
||||
|
||||
ppTilingData.isTransA = false;
|
||||
ppTilingData.isTransB = isTransB;
|
||||
ppTilingData.isGatherAddOut = *(attrs->GetAttrPointer<bool>(ATTR_IS_GATHER_ADD_OUT_INDEX));
|
||||
|
||||
auto &opShape = ppTilingData.opShape;
|
||||
auto &tensor0Shape = context->GetInputTensor(0)->GetOriginShape();
|
||||
uint32_t dimNum = tensor0Shape.GetDimNum();
|
||||
int64_t bs;
|
||||
int64_t rankM;
|
||||
int64_t rankK;
|
||||
|
||||
if (dimNum == DIM_NUM_THREE) {
|
||||
bs = tensor0Shape.GetDim(DIM_INDEX_ZERO);
|
||||
rankM = tensor0Shape.GetDim(DIM_INDEX_ONE);
|
||||
rankK = tensor0Shape.GetDim(DIM_INDEX_TWO);
|
||||
} else if (dimNum == DIM_NUM_TWO) {
|
||||
bs = BATCH_SIZE_ONE;
|
||||
rankM = tensor0Shape.GetDim(DIM_INDEX_ZERO);
|
||||
rankK = tensor0Shape.GetDim(DIM_INDEX_ONE);
|
||||
} else {
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OPS_LOG_E(nodeName, "Tiling input dim error.");
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int64_t rankN = isTransB ?
|
||||
context->GetInputTensor(1)->GetOriginShape().GetDim(DIM_INDEX_ZERO) :
|
||||
context->GetInputTensor(1)->GetOriginShape().GetDim(DIM_INDEX_ONE);
|
||||
|
||||
opShape.batchSize = BATCH_SIZE_ONE;
|
||||
opShape.m = bs * rankM;
|
||||
opShape.n = rankN;
|
||||
opShape.k = rankK;
|
||||
|
||||
commTilingData.rankSize = static_cast<int32_t>(*RankSizePtr);
|
||||
commTilingData.rank = static_cast<int32_t>(*RankIdPtr);
|
||||
if (commTilingData.rankSize == RANKSIZE_EIGHT) {
|
||||
AllReduceEightRankFP16GetDefaultTiling(context, ppTilingData, commTilingData);
|
||||
} else {
|
||||
GetDefaultTiling(context, ppTilingData, commTilingData);
|
||||
}
|
||||
|
||||
uint32_t calcBytes = 0;
|
||||
uint32_t sLength = 1;
|
||||
std::vector<int64_t> shapeVec = {1, 1, rankN};
|
||||
std::vector<int64_t> oriShapeVec = shapeVec;
|
||||
auto EpsilonPtr = attrs->GetAttrPointer<float>(ATTR_EPSILON_INDEX);
|
||||
float epsilon = static_cast<float>(*EpsilonPtr);
|
||||
GetRmsnormTilingData(
|
||||
rmsnormTilingData, shapeVec, oriShapeVec, calcBytes, commTilingData.rankSize * sLength * rankN, epsilon);
|
||||
GetQuantTilingData(quantInfo);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
bool IsMatrixAligned(const int64_t &m, const int64_t &n, const bool &transpose, int nElemAlign)
|
||||
{
|
||||
return (transpose ? m : n) % nElemAlign == 0;
|
||||
}
|
||||
|
||||
int64_t GetAlignedMatrixSize(
|
||||
const int64_t &batchSize, const int64_t &m, const int64_t &n, const bool &transpose, int nElemAlign)
|
||||
{
|
||||
int64_t nRow = transpose ? n : m;
|
||||
int64_t nCol = transpose ? m : n;
|
||||
int64_t nColAlign = (nCol + nElemAlign - 1) / nElemAlign * nElemAlign;
|
||||
return batchSize * nRow * nColAlign;
|
||||
}
|
||||
|
||||
WorkspaceDetail GetWorkspaceDetail(CoCDataTypeDesc dataType, const MatMulInfo &mmInfo, const QuantInfo &quantInfo)
|
||||
{
|
||||
WorkspaceDetail workspaceDetail;
|
||||
|
||||
int32_t eleSize = COC_TYPE2ELE_SIZE.at(dataType);
|
||||
int32_t nElemAlign = ALIGN_BYTES / eleSize;
|
||||
|
||||
bool hasQuant = quantInfo.quantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
if (hasQuant || (!IsMatrixAligned(mmInfo.m, mmInfo.k, mmInfo.transA, nElemAlign) && mmInfo.m != 1)) {
|
||||
workspaceDetail.matrixActivationSize =
|
||||
GetAlignedMatrixSize(mmInfo.batchSize, mmInfo.m, mmInfo.k, mmInfo.transA, nElemAlign) * eleSize;
|
||||
}
|
||||
|
||||
bool hasDequant = quantInfo.dequantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
if ((hasDequant && !mmInfo.isInt8) || !IsMatrixAligned(mmInfo.k, mmInfo.n, mmInfo.transB, nElemAlign)) {
|
||||
workspaceDetail.matrixWeightSize =
|
||||
GetAlignedMatrixSize(mmInfo.batchSize, mmInfo.k, mmInfo.n, mmInfo.transB, nElemAlign) * eleSize;
|
||||
}
|
||||
|
||||
bool hasAccum = dataType == CoCDataTypeDesc::INT8INT8_INT32_BF16;
|
||||
if (hasAccum) {
|
||||
workspaceDetail.matrixIntermediateSize =
|
||||
static_cast<int64_t>(mmInfo.batchSize) * mmInfo.m * mmInfo.n * sizeof(int32_t);
|
||||
}
|
||||
|
||||
if (mmInfo.isInt8) {
|
||||
workspaceDetail.formatDequantParamSize =
|
||||
mmInfo.k > mmInfo.n ? mmInfo.k * sizeof(float) : mmInfo.n * sizeof(float);
|
||||
}
|
||||
return workspaceDetail;
|
||||
}
|
||||
|
||||
void GetMmInfo(gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling, MatMulInfo *mmInfo)
|
||||
{
|
||||
PPTilingData tempPPTilingData = tiling->matmulAllreduceAddRmsnormInfo.ppTilingData;
|
||||
mmInfo->batchSize = tempPPTilingData.opShape.batchSize;
|
||||
mmInfo->m = tempPPTilingData.opShape.m;
|
||||
mmInfo->n = tempPPTilingData.opShape.n;
|
||||
mmInfo->k = tempPPTilingData.opShape.k;
|
||||
auto attrs = context->GetAttrs();
|
||||
mmInfo->transA = false;
|
||||
mmInfo->transB = *(attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B_INDEX));
|
||||
mmInfo->withBias = false;
|
||||
mmInfo->weightNz = false;
|
||||
mmInfo->isInt8 = context->GetInputTensor(0)->GetDataType() == ge::DT_INT8;
|
||||
}
|
||||
|
||||
size_t GetUserWorkspaceSize(gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling)
|
||||
{
|
||||
MatMulInfo mmInfo;
|
||||
GetMmInfo(context, tiling, &mmInfo);
|
||||
QuantInfo quantInfo = tiling->matmulAllreduceAddRmsnormInfo.quantInfo;
|
||||
return GetWorkspaceDetail(FP16FP16_FP32_FP16, mmInfo, quantInfo).GetSize();
|
||||
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName)
|
||||
{
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
|
||||
size_t systemWorkspaceSize = static_cast<size_t>(ascendcPlatform.GetLibApiWorkSpaceSize());
|
||||
MatmulAllreduceAddRmsnormTilingData *tilingData = context->GetTilingData<MatmulAllreduceAddRmsnormTilingData>();
|
||||
size_t userWorkspaceSize = GetUserWorkspaceSize(context, tilingData);
|
||||
workSpaces[0] = userWorkspaceSize + systemWorkspaceSize;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void SetHcommCfg(
|
||||
const gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling, const std::string groupTp)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
uint32_t opType = OP_TYPE_ALL_TO_ALL;
|
||||
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh";
|
||||
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupTp, opType, algConfigAllToAllStr);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
|
||||
}
|
||||
|
||||
static ge::graphStatus MatmulAllreduceAddRmsnormTilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
MatmulAllreduceAddRmsnormTilingData *tilingData = context->GetTilingData<MatmulAllreduceAddRmsnormTilingData>();
|
||||
OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
|
||||
|
||||
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, "hcomms");
|
||||
|
||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||
OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||
uint32_t aicNum_ = ascendcPlatform.GetCoreNumAic();
|
||||
context->SetBlockDim(aicNum_);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus MatmulAllreduceAddRmsnormTilingFunc(gert::TilingContext *context)
|
||||
{
|
||||
ge::graphStatus ret = MatmulAllreduceAddRmsnormTilingFuncImpl(context);
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct MatmulAllreduceAddRmsnormCompileInfo {};
|
||||
ge::graphStatus TilingParseForMatmulAllreduceAddRmsnorm(gert::TilingParseContext *context)
|
||||
{
|
||||
(void)context;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(MatmulAllreduceAddRmsnorm)
|
||||
.Tiling(MatmulAllreduceAddRmsnormTilingFunc)
|
||||
.TilingParse<MatmulAllreduceAddRmsnormCompileInfo>(TilingParseForMatmulAllreduceAddRmsnorm);
|
||||
@@ -0,0 +1,79 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_WORKSPACE_H
|
||||
#define MATMUL_ALLREDUCE_ADD_RMSNORM_WORKSPACE_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#pragma once
|
||||
const constexpr uint32_t ALIGN_BYTES = 512;
|
||||
const constexpr int32_t INT8_ELE_SIZE = 1;
|
||||
const constexpr int32_t FP_BF_16_ELE_SIZE = 2;
|
||||
|
||||
enum CoCDataTypeDesc : int {
|
||||
COC_DATA_TYPE_UNDEFINED = -1,
|
||||
FP16FP16_FP32_FP16 = 0,
|
||||
BF16BF16_FP32_BF16 = 1,
|
||||
INT8INT8_INT32_FP16 = 2,
|
||||
INT8INT8_INT32_BF16 = 3,
|
||||
FP16INT8_INT32_FP16 = 4,
|
||||
BF16INT8_INT32_BF16 = 5,
|
||||
FP16INT8_FP32_FP16 = 6,
|
||||
BF16INT8_FP32_BF16 = 7,
|
||||
FP16INT4_FP32_FP16 = 8,
|
||||
BF16INT4_FP32_BF16 = 9,
|
||||
COC_DATA_TYPE_DESC_MAX = 10,
|
||||
};
|
||||
|
||||
const std::map<CoCDataTypeDesc, int32_t> COC_TYPE2ELE_SIZE = {
|
||||
{FP16FP16_FP32_FP16, FP_BF_16_ELE_SIZE},
|
||||
{BF16BF16_FP32_BF16, FP_BF_16_ELE_SIZE},
|
||||
{INT8INT8_INT32_FP16, INT8_ELE_SIZE},
|
||||
{INT8INT8_INT32_BF16, INT8_ELE_SIZE},
|
||||
{FP16INT8_INT32_FP16, INT8_ELE_SIZE},
|
||||
{BF16INT8_INT32_BF16, INT8_ELE_SIZE},
|
||||
{FP16INT8_FP32_FP16, FP_BF_16_ELE_SIZE},
|
||||
{BF16INT8_FP32_BF16, FP_BF_16_ELE_SIZE},
|
||||
{FP16INT4_FP32_FP16, FP_BF_16_ELE_SIZE},
|
||||
{BF16INT4_FP32_BF16, FP_BF_16_ELE_SIZE}
|
||||
};
|
||||
|
||||
struct MatMulInfo {
|
||||
int64_t batchSize = 1;
|
||||
int64_t m = -1;
|
||||
int64_t n = -1;
|
||||
int64_t k = -1;
|
||||
bool transA = false;
|
||||
bool transB = false;
|
||||
bool withBias = false;
|
||||
bool isInt8 = false;
|
||||
bool weightNz = false;
|
||||
};
|
||||
|
||||
struct WorkspaceDetail {
|
||||
int64_t matrixActivationSize{0};
|
||||
int64_t matrixWeightSize{0};
|
||||
int64_t matrixIntermediateSize{0};
|
||||
int64_t formatDequantParamSize{0};
|
||||
|
||||
int64_t GetSize() const
|
||||
{
|
||||
return matrixActivationSize + matrixWeightSize + matrixIntermediateSize + formatDequantParamSize;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "lib/matmul_intf.h"
|
||||
#include <kernel_operator.h>
|
||||
#include "matmul_allreduce_add_rmsnorm_aic_kernel.h"
|
||||
#include "matmul_allreduce_add_rmsnorm_aiv_kernel.h"
|
||||
|
||||
extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR residual,
|
||||
GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(MatmulAllreduceAddRmsnormTilingData);
|
||||
GET_TILING_DATA(tiling_data, tiling);
|
||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
|
||||
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
auto tilingData = (__gm__ MatmulAllreduceAddRmsnormTilingData*)tiling;
|
||||
__gm__ void* mc2InitTiling = (__gm__ void*)(&(tilingData->mc2InitTiling));
|
||||
__gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling));
|
||||
auto contextGM0 = AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
|
||||
|
||||
if ASCEND_IS_AIC {
|
||||
MatmulAllreduceAddRmsnormAicKernel<DTYPE_X1, DTYPE_Y> op;
|
||||
op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_);
|
||||
op.Process();
|
||||
return;
|
||||
}
|
||||
|
||||
if ASCEND_IS_AIV {
|
||||
MatmulAllreduceAddRmsnormAivKernel<DTYPE_X1, DTYPE_Y> op;
|
||||
|
||||
op.Init(x1, x2, residual, gamma, y, add_out, workspace, &tiling_data, hccl_);
|
||||
op.Process(&tiling_data);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H
|
||||
#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H
|
||||
|
||||
#define ASCENDC_CUBE_ONLY
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/arch.hpp"
|
||||
#include "catlass/gemm/block/block_mmad.hpp"
|
||||
#include "catlass/gemm/block/block_swizzle.hpp"
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
#include "catlass/gemm/kernel/basic_matmul.hpp"
|
||||
#include "catlass/gemm/gemm_type.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
|
||||
#include "matmul_allreduce_add_rmsnorm_utils.h"
|
||||
#include "matmul_allreduce_add_rmsnorm_tiling.h"
|
||||
|
||||
constexpr int32_t SCALE_L1_SIZE_A = 256 * 8;
|
||||
constexpr int32_t SCALE_L1_SIZE_B = 128 * 1024;
|
||||
constexpr int32_t CUBE_MATRIX_SIZE_B16 = 256; // 16 * 16
|
||||
constexpr int32_t CUBE_MATRIX_SIZE_B8 = 16 * 32; // 16 * 32
|
||||
constexpr int32_t SCALE_L1_SIZE = 256 * 8; // 2 KB
|
||||
constexpr int32_t BLOCK_SIZE_16 = 16;
|
||||
constexpr int32_t BLOCK_SIZE_32 = 32;
|
||||
constexpr int32_t DOUBLE_BUFFER_SIZE = 2;
|
||||
constexpr uint32_t MM_L1_TILE_SHAPE_M = 128;
|
||||
constexpr uint32_t MM_L1_TILE_SHAPE_N = 256;
|
||||
constexpr uint32_t MM_L1_TILE_SHAPE_K = 256;
|
||||
constexpr uint32_t MM_L0_TILE_SHAPE_M = MM_L1_TILE_SHAPE_M;
|
||||
constexpr uint32_t MM_L0_TILE_SHAPE_N = MM_L1_TILE_SHAPE_N;
|
||||
constexpr uint32_t MM_L0_TILE_SHAPE_K = 64;
|
||||
|
||||
using namespace Catlass;
|
||||
|
||||
template <typename T_INPUT>
|
||||
struct GetAccumType {
|
||||
using T = float;
|
||||
};
|
||||
|
||||
__aicore__ inline bool IsQuant(const QuantGranularity &granularity)
|
||||
{
|
||||
return (granularity > QuantGranularity::QUANT_GRANULARITY_UNDEFINED) &&
|
||||
(granularity < QuantGranularity::QUANT_GRANULARITY_MAX);
|
||||
}
|
||||
|
||||
template <typename MmadDtype, typename OutDtype>
|
||||
class MatmulAllreduceAddRmsnormAicKernel {
|
||||
using T_ACCUM = typename GetAccumType<MmadDtype>::T;
|
||||
public:
|
||||
int PIPE_DEPTH = 2;
|
||||
Arch::Resource<Arch::AtlasA2> resource;
|
||||
__aicore__ inline MatmulAllreduceAddRmsnormAicKernel<MmadDtype, OutDtype>() { }
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y,
|
||||
GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData* tilingData,
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> &hccl_)
|
||||
{
|
||||
this->hccl_ = hccl_;
|
||||
this->gm_c = reinterpret_cast<__gm__ OutDtype *>(y);
|
||||
|
||||
this->gm_dequant_scale = nullptr;
|
||||
this->has_offset = false;
|
||||
|
||||
auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData;
|
||||
auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData;
|
||||
auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo;
|
||||
|
||||
this->batch_size = ppTilingData->opShape.batchSize;
|
||||
this->m = ppTilingData->opShape.m;
|
||||
this->k = ppTilingData->opShape.k;
|
||||
this->n = ppTilingData->opShape.n;
|
||||
this->weight_nz = false;
|
||||
|
||||
this->is_int8 = false;
|
||||
this->cube_matrix_size = this->is_int8 ? CUBE_MATRIX_SIZE_B8 : CUBE_MATRIX_SIZE_B16;
|
||||
|
||||
this->m_align = Block512B<MmadDtype>::AlignUp(m);
|
||||
this->k_align = Block512B<MmadDtype>::AlignUp(k);
|
||||
this->n_align = Block512B<MmadDtype>::AlignUp(n);
|
||||
|
||||
this->m0 = ppTilingData->m0;
|
||||
this->k0 = ppTilingData->k0;
|
||||
this->n0 = ppTilingData->n0;
|
||||
|
||||
int32_t tiling_key = ppTilingData->tilingKey;
|
||||
this->trans_a = ppTilingData->isTransA;
|
||||
this->trans_b = ppTilingData->isTransB;
|
||||
|
||||
int32_t aligned_a;
|
||||
int32_t aligned_b;
|
||||
this->dequant_granularity = quantInfo->dequantGranularity;
|
||||
AlignJudge(this->trans_a, this->trans_b, this->m, this->k, this->n,
|
||||
this->m_align, this->k_align, this->n_align, aligned_a, aligned_b);
|
||||
this->aligned_a = aligned_a;
|
||||
this->aligned_b = aligned_b;
|
||||
if (weight_nz) {
|
||||
this->k_align16 = Block32B<MmadDtype>::AlignUp(k);
|
||||
this->n_align16 = Block32B<MmadDtype>::AlignUp(n);
|
||||
}
|
||||
bool has_a_align = IsQuant(quantInfo->quantGranularity) || aligned_a;
|
||||
bool has_b_align = IsQuant(this->dequant_granularity) && !this->is_int8 || aligned_b;
|
||||
bool has_accum = IsQuant(this->dequant_granularity) &&
|
||||
this->is_int8 && std::is_same<OutDtype, bfloat16_t>::value;
|
||||
bool has_format_dequant_offset =
|
||||
(this->dequant_granularity == QuantGranularity::PER_TENSOR) && this->is_int8 && this->has_offset;
|
||||
auto workspace_info = GetWorkspaceInfo(workspace, this->batch_size, this->m, this->k, this->n,
|
||||
this->m_align, this->k_align, this->n_align, this->trans_a, this->trans_b,
|
||||
sizeof(MmadDtype), has_a_align, has_b_align, has_accum, has_format_dequant_offset);
|
||||
this->gm_a_src = reinterpret_cast<__gm__ MmadDtype *>(x1);
|
||||
this->gm_b_src = reinterpret_cast<__gm__ MmadDtype *>(x2);
|
||||
this->gm_format_dequant_offset = reinterpret_cast<__gm__ int32_t *>(has_format_dequant_offset ?
|
||||
workspace_info.gm_dequant_param : nullptr);
|
||||
this->gm_workspace_src = workspace;
|
||||
this->block_size = BLOCK_SIZE_32 / sizeof(MmadDtype);
|
||||
|
||||
int32_t a_l1_size = this->m0 * this->k0 * sizeof(MmadDtype);
|
||||
int32_t a_l1_size_round = AscendC::DivCeil(a_l1_size, 512) * 512;
|
||||
int32_t b_l1_size = this->n0 * this->k0 * sizeof(MmadDtype);
|
||||
int32_t b_l1_size_round = AscendC::DivCeil(b_l1_size, 512) * 512;
|
||||
this->l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t)(this->is_int8 ? SCALE_L1_SIZE : 0));
|
||||
this->l1_base_b =
|
||||
reinterpret_cast<__cbuf__ MmadDtype *>(a_l1_size_round * (this->is_int8 ? DOUBLE_BUFFER_SIZE : 1) +
|
||||
(uintptr_t) this->l1_base_a);
|
||||
|
||||
this->core_num = get_block_num();
|
||||
this->core_idx = get_block_idx();
|
||||
|
||||
this->m_loop = ppTilingData->mLoop;
|
||||
this->k_loop = ppTilingData->kLoop;
|
||||
this->n_loop = ppTilingData->nLoop;
|
||||
this->core_loop = ppTilingData->coreLoop;
|
||||
this->swizzl_count = ppTilingData->swizzlCount;
|
||||
this->swizzl_direct = ppTilingData->swizzlDirect;
|
||||
this->is_91093 = commTilingData->is91093;
|
||||
this->ping_flag = 1;
|
||||
this->rank = hccl_.GetRankId();
|
||||
this->rank_size = hccl_.GetRankDim();
|
||||
this->withSerialMode = commTilingData->withSerialMode;
|
||||
|
||||
this->gm_peer_mem = (__gm__ OutDtype *)hccl_.GetWindowsInAddr(this->rank);
|
||||
}
|
||||
|
||||
__aicore__ inline void MoveL0CToGM(__gm__ OutDtype *gm_dst, int64_t offset_c,
|
||||
int32_t m_actual, int32_t n_actual, int32_t src_stride, int32_t dst_stride) {
|
||||
if constexpr (std::is_same<OutDtype, __bf16>::value) {
|
||||
copy_matrix_cc_to_gm(
|
||||
gm_dst + offset_c,
|
||||
l0c_buf,
|
||||
0,
|
||||
n_actual,
|
||||
m_actual,
|
||||
dst_stride,
|
||||
src_stride,
|
||||
0,
|
||||
F322BF16,
|
||||
0,
|
||||
false,
|
||||
true
|
||||
);
|
||||
} else {
|
||||
copy_matrix_cc_to_gm(
|
||||
gm_dst + offset_c,
|
||||
l0c_buf,
|
||||
0,
|
||||
n_actual,
|
||||
m_actual,
|
||||
dst_stride,
|
||||
src_stride,
|
||||
0,
|
||||
F322F16,
|
||||
0,
|
||||
false,
|
||||
true
|
||||
);
|
||||
}
|
||||
SetFlag<HardEvent::FIX_M>(EVENT_ID0);
|
||||
}
|
||||
|
||||
__aicore__ inline void InitFlags()
|
||||
{
|
||||
WaitEvent(AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID);
|
||||
}
|
||||
|
||||
__aicore__ inline void Endflags()
|
||||
{
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
// AIC matmul func, waits for AIV to complete [AllReduce & Add & RMSNorm].
|
||||
InitFlags();
|
||||
uint32_t m = this->m;
|
||||
uint32_t k = this->k;
|
||||
uint32_t n = this->n;
|
||||
gmB.SetGlobalBuffer(gm_b_src, k * n);
|
||||
|
||||
using LayoutA = layout::RowMajor;
|
||||
using LayoutB = layout::ColumnMajor;
|
||||
using LayoutC = layout::RowMajor;
|
||||
LayoutB layoutB {(layout::ColumnMajor::Index)k, (layout::ColumnMajor::Index)n};
|
||||
|
||||
using L1TileShape = GemmShape<MM_L1_TILE_SHAPE_M, MM_L1_TILE_SHAPE_N, MM_L1_TILE_SHAPE_K>;
|
||||
using L0TileShape = GemmShape<MM_L0_TILE_SHAPE_M, MM_L0_TILE_SHAPE_N, MM_L0_TILE_SHAPE_K>;
|
||||
using AType = Gemm::GemmType<MmadDtype, LayoutA>;
|
||||
using BType = Gemm::GemmType<MmadDtype, LayoutB>;
|
||||
using CType = AType;
|
||||
constexpr bool ENABLE_UNIT_FLAG = true;
|
||||
using MmadDispatchPolicy = Gemm::MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG>;
|
||||
using BlockMmad = Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
|
||||
GemmCoord blockShape = L1TileShape::ToCoord();
|
||||
|
||||
BlockMmad blockMmad(resource);
|
||||
int mPerSplit = this->m0 * this->swizzl_count;
|
||||
int mAvg = mPerSplit;
|
||||
int splitM = AscendC::DivCeil(m, mPerSplit);
|
||||
int flag_idx = 0;
|
||||
icache_preload(8); // 8 corresponding to 16k
|
||||
for (int splitIndex = 0; splitIndex < splitM; ++splitIndex) {
|
||||
uint32_t mStart = splitIndex * mAvg;
|
||||
uint32_t mActual = mAvg > (m - mStart) ? m - mStart:mAvg;
|
||||
flag_idx = splitIndex % PIPE_DEPTH;
|
||||
if (splitIndex >= PIPE_DEPTH) {
|
||||
WaitEvent(flag_idx);
|
||||
}
|
||||
|
||||
__gm__ MmadDtype *gm_a_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_a_src) + mStart * k;
|
||||
__gm__ MmadDtype *gm_c_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) + mStart * n;
|
||||
gmA.SetGlobalBuffer(gm_a_src_tmp, mActual*k);
|
||||
gmC.SetGlobalBuffer(gm_c_src_tmp, mActual*n);
|
||||
|
||||
GemmCoord splitShape{mActual, n, k};
|
||||
using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 1>; // SwizzleOffset=3
|
||||
BlockScheduler splitScheduler(splitShape, blockShape.GetCoordMN());
|
||||
uint32_t coreLoops = splitScheduler.GetCoreLoops();
|
||||
|
||||
LayoutA layoutA{mActual, k};
|
||||
LayoutC layoutC{mActual, n};
|
||||
|
||||
for (uint32_t loopIdx = core_idx; loopIdx < coreLoops; loopIdx += core_num) {
|
||||
GemmCoord blockCoord = splitScheduler.GetBlockCoord(loopIdx);
|
||||
GemmCoord actualBlockShape = splitScheduler.GetActualBlockShape(blockCoord);
|
||||
GemmCoord offsetCoord = blockCoord * blockShape;
|
||||
|
||||
MatrixCoord offsetA = offsetCoord.GetCoordMK();
|
||||
MatrixCoord offsetB = offsetCoord.GetCoordKN();
|
||||
MatrixCoord offsetC = offsetCoord.GetCoordMN();
|
||||
|
||||
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
||||
int64_t gmOffsetB = layoutB.GetOffset(offsetB);
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
|
||||
blockMmad (gmA[gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], layoutC, actualBlockShape);
|
||||
}
|
||||
|
||||
FFTSCrossCoreSync<PIPE_FIX>(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx);
|
||||
}
|
||||
|
||||
Endflags();
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
private:
|
||||
AscendC::GlobalTensor<MmadDtype> gmA;
|
||||
AscendC::GlobalTensor<MmadDtype> gmB;
|
||||
AscendC::GlobalTensor<MmadDtype> gmC;
|
||||
__gm__ MmadDtype *gm_a_src{nullptr};
|
||||
__gm__ MmadDtype *gm_b_src{nullptr};
|
||||
|
||||
__gm__ OutDtype *gm_c{nullptr};
|
||||
__gm__ OutDtype *gm_peer_mem{nullptr};
|
||||
__gm__ int64_t *gm_dequant_scale{nullptr};
|
||||
__gm__ int32_t *gm_format_dequant_offset{nullptr};
|
||||
__gm__ int32_t *gm_accum{nullptr};
|
||||
__gm__ uint8_t *gm_workspace_src;
|
||||
|
||||
__cbuf__ MmadDtype *l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_A);
|
||||
__cbuf__ MmadDtype *l1_base_b = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_B);
|
||||
|
||||
__ca__ MmadDtype *l0a_base = reinterpret_cast<__ca__ MmadDtype *>((uintptr_t) 0);
|
||||
__cb__ MmadDtype *l0b_base = reinterpret_cast<__cb__ MmadDtype *>((uintptr_t) 0);
|
||||
|
||||
__cc__ T_ACCUM *l0c_buf = reinterpret_cast<__cc__ T_ACCUM *>((uintptr_t) 0);
|
||||
|
||||
__cbuf__ int64_t *scale_l1 = reinterpret_cast<__cbuf__ int64_t *>((uintptr_t) 0);
|
||||
__fbuf__ int64_t *scale_FB = (__fbuf__ int64_t *)(0);
|
||||
|
||||
__cbuf__ int32_t *bias_l1 = reinterpret_cast<__cbuf__ int32_t *>((uintptr_t)0);
|
||||
uint16_t bias_bt = 0;
|
||||
bool has_offset{false};
|
||||
|
||||
int32_t core_num;
|
||||
|
||||
int32_t batch_size;
|
||||
int32_t m;
|
||||
int32_t k;
|
||||
int32_t n;
|
||||
int32_t m_align;
|
||||
int32_t k_align;
|
||||
int32_t n_align;
|
||||
int32_t k_align16;
|
||||
int32_t n_align16;
|
||||
int32_t m0;
|
||||
int32_t k0;
|
||||
int32_t n0;
|
||||
|
||||
int32_t m_loop;
|
||||
int32_t n_loop;
|
||||
int32_t k_loop;
|
||||
int32_t core_loop;
|
||||
int32_t core_idx;
|
||||
int32_t ping_flag;
|
||||
int32_t block_size;
|
||||
int32_t cube_matrix_size;
|
||||
|
||||
int32_t aligned_a;
|
||||
int32_t aligned_b;
|
||||
|
||||
int32_t swizzl_count;
|
||||
int32_t swizzl_direct;
|
||||
|
||||
int32_t rank;
|
||||
int32_t rank_size;
|
||||
|
||||
int32_t withSerialMode;
|
||||
|
||||
int32_t ag_dim;
|
||||
int32_t rs_dim;
|
||||
bool inner_dim_is_Ag{false};
|
||||
int32_t ag_rank_idx;
|
||||
int32_t rs_rank_idx;
|
||||
bool weight_nz{false};
|
||||
|
||||
bool is_91093{false};
|
||||
QuantGranularity dequant_granularity;
|
||||
|
||||
bool is_int8;
|
||||
bool trans_a;
|
||||
bool trans_b;
|
||||
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
};
|
||||
|
||||
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H
|
||||
@@ -0,0 +1,702 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H
|
||||
#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "matmul_allreduce_add_rmsnorm_tiling.h"
|
||||
#include "matmul_allreduce_add_rmsnorm_utils.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr int32_t DIFUSION_ADD_LEN = 512;
|
||||
constexpr int32_t TQUE_DEPTH = 1;
|
||||
constexpr uint32_t TBUF_POOL_MAX_BUFID_SIZE = 8;
|
||||
enum CrossRankSyncFlagEnum {
|
||||
FLAG_ZERO_IDX,
|
||||
FLAG_ONE_IDX,
|
||||
FLAG_TWO_IDX,
|
||||
FLAG_ADD_IDX,
|
||||
FLAG_FOUR_IDX,
|
||||
FLAG_GATHER_ADD_OUT_STEP1,
|
||||
FLAG_GATHER_ADD_OUT_STEP2,
|
||||
FLAG_NUM
|
||||
};
|
||||
constexpr int32_t FLAG_VALUE = 1;
|
||||
constexpr int32_t NUM_PER_REP_FP32 = 64;
|
||||
|
||||
template <typename T>
|
||||
__aicore__ void CopyUbufToGmAlignB16(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint32_t lenBurst,
|
||||
uint16_t srcSTride, uint16_t dstStride)
|
||||
{
|
||||
DataCopyExtParams dataCopyParams(nBurst,
|
||||
lenBurst,
|
||||
srcSTride,
|
||||
dstStride,
|
||||
0);
|
||||
LocalTensor<uint8_t> ubTensor;
|
||||
TBuffAddr ubAddr;
|
||||
ubAddr.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ubAddr.bufferAddr = reinterpret_cast<uint64_t>(src);
|
||||
ubTensor.SetAddr(ubAddr);
|
||||
GlobalTensor<uint8_t> gmTensor;
|
||||
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(dst));
|
||||
DataCopyPad(gmTensor, ubTensor, dataCopyParams);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ void CopyGmToUbufAlignB16(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst,
|
||||
uint16_t srcSTride, uint16_t dstStride)
|
||||
{
|
||||
DataCopyExtParams dataCopyParams(nBurst,
|
||||
lenBurst,
|
||||
srcSTride,
|
||||
dstStride,
|
||||
0);
|
||||
LocalTensor<uint8_t> ubTensor;
|
||||
TBuffAddr ubAddr;
|
||||
ubAddr.logicPos = static_cast<uint8_t>(TPosition::VECIN);
|
||||
ubAddr.bufferAddr = reinterpret_cast<uint64_t>(dst);
|
||||
ubTensor.SetAddr(ubAddr);
|
||||
GlobalTensor<uint8_t> gmTensor;
|
||||
gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(src));
|
||||
DataCopyPadExtParams<uint8_t> padParams;
|
||||
DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams);
|
||||
}
|
||||
|
||||
template <typename MmadDtype, typename OutDtype>
|
||||
class MatmulAllreduceAddRmsnormAivKernel {
|
||||
|
||||
public:
|
||||
__aicore__ inline MatmulAllreduceAddRmsnormAivKernel<MmadDtype, OutDtype>() { }
|
||||
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out,
|
||||
GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData *tilingData,
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> &hccl_)
|
||||
{
|
||||
this->hccl_ = hccl_;
|
||||
is_deterministic = false;
|
||||
auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData;
|
||||
auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData;
|
||||
auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo;
|
||||
|
||||
gm_out = reinterpret_cast<__gm__ MmadDtype *>(y);
|
||||
gm_add_input = reinterpret_cast<__gm__ MmadDtype *>(residual);
|
||||
gm_add_output = reinterpret_cast<__gm__ MmadDtype *>(add_out);
|
||||
gm_gamma = reinterpret_cast<__gm__ MmadDtype *>(gamma);
|
||||
|
||||
batch_size = ppTilingData->opShape.batchSize;
|
||||
m = ppTilingData->opShape.m;
|
||||
k = ppTilingData->opShape.k;
|
||||
n = ppTilingData->opShape.n;
|
||||
|
||||
m0 = ppTilingData->m0;
|
||||
k0 = ppTilingData->k0;
|
||||
n0 = ppTilingData->n0;
|
||||
|
||||
m_loop = ppTilingData->mLoop;
|
||||
k_loop = ppTilingData->kLoop;
|
||||
n_loop = ppTilingData->nLoop;
|
||||
|
||||
core_loop = ppTilingData->coreLoop;
|
||||
swizzl_count = ppTilingData->swizzlCount;
|
||||
tiling_key = ppTilingData->tilingKey;
|
||||
rank = hccl_.GetRankId();
|
||||
rank_size = hccl_.GetRankDim();
|
||||
|
||||
max_ub_single_dma_size = commTilingData->ubMoveNum;
|
||||
withSerialMode = false;
|
||||
tag = commTilingData->tag;
|
||||
comm_npu_split = commTilingData->commNpuSplit;
|
||||
comm_data_split = commTilingData->commDataSplit;
|
||||
comm_direct = commTilingData->commDirect;
|
||||
is_91093 = false;
|
||||
core_count = comm_npu_split * comm_data_split;
|
||||
dequant_granularity = static_cast<QuantGranularity>(quantInfo->dequantGranularity);
|
||||
dequant_group_size = quantInfo->dequantGroupSize;
|
||||
quant_granularity = static_cast<QuantGranularity>(quantInfo->quantGranularity);
|
||||
quant_group_size = quantInfo->quantGroupSize;
|
||||
epsilon = tilingData->matmulAllreduceAddRmsnormInfo.rmsnormTilingData.epsilon;
|
||||
is_gather_add_out = tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData.isGatherAddOut;
|
||||
|
||||
swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false;
|
||||
trans_a = ppTilingData->isTransA;
|
||||
trans_b = ppTilingData->isTransB;
|
||||
is_int8 = false;
|
||||
ag_dim = 0;
|
||||
rs_dim = 0;
|
||||
inner_dim_is_Ag = false;
|
||||
weight_nz = false;
|
||||
max_ub_ping_pong_size = max_ub_single_dma_size / 2; // 2 - double buffer
|
||||
|
||||
core_idx = get_block_idx();
|
||||
core_num = get_block_num();
|
||||
aiv_idx = get_subblockid();
|
||||
other_rank = (core_idx < rank_size) ? core_idx : -1;
|
||||
|
||||
// init ub usage
|
||||
pipe.InitBuffer(ctrlBuf, AscendC::ONE_BLK_SIZE);
|
||||
ub_ctrl_flag = reinterpret_cast<__ubuf__ int32_t *>(ctrlBuf.Get<int32_t>().GetPhyAddr());
|
||||
|
||||
pipe.InitBuffer(gammaBuf, n * sizeof(MmadDtype));
|
||||
|
||||
uint32_t step1_ub_usage = AscendC::AlignUp(
|
||||
n * sizeof(MmadDtype) +
|
||||
2 * (rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype)) +
|
||||
n * sizeof(MmadDtype) +
|
||||
n * sizeof(MmadDtype) +
|
||||
n * sizeof(float) +
|
||||
n * sizeof(float) +
|
||||
n * sizeof(float),
|
||||
AscendC::ONE_BLK_SIZE);
|
||||
|
||||
uint32_t step2_ub_usage = AscendC::AlignUp(
|
||||
max_ub_ping_pong_size * sizeof(MmadDtype),
|
||||
AscendC::ONE_BLK_SIZE) * 2;
|
||||
uint32_t max_step_ub_usage = max(step1_ub_usage, step2_ub_usage);
|
||||
|
||||
pipe.InitBufPool(step1BufPool, max_step_ub_usage);
|
||||
pipe.InitBufPool(step2BufPool, max_step_ub_usage, step1BufPool);
|
||||
|
||||
step1BufPool.InitBuffer(inQueueX, 1, n * sizeof(MmadDtype));
|
||||
step1BufPool.InitBuffer(inQueueY, 2, rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype));
|
||||
step1BufPool.InitBuffer(addOutQueue, 1, n * sizeof(MmadDtype));
|
||||
step1BufPool.InitBuffer(outQueue, 1, n * sizeof(MmadDtype));
|
||||
step1BufPool.InitBuffer(xFp32Buf, n * sizeof(float));
|
||||
step1BufPool.InitBuffer(sqxBuf, n * sizeof(float));
|
||||
step1BufPool.InitBuffer(reduceFp32Buf, n * sizeof(float));
|
||||
|
||||
step2BufPool.InitBuffer(allgatherBuf[0], max_ub_ping_pong_size * sizeof(MmadDtype));
|
||||
step2BufPool.InitBuffer(allgatherBuf[1], max_ub_ping_pong_size * sizeof(MmadDtype));
|
||||
|
||||
CopyInGamma();
|
||||
}
|
||||
|
||||
__aicore__ inline void Process(const MatmulAllreduceAddRmsnormTilingData *tilingData)
|
||||
{
|
||||
// AIV AllReduce & Add & RMSNorm func, waits for AIC to complete [Matmul].
|
||||
FFTSCrossCoreSync<PIPE_MTE3>(FFTS_SYNC_AICORE_GROUP_MODE, AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
|
||||
ResetIpcFlags(FLAG_NUM);
|
||||
CrossRankSyncEx(FLAG_NUM);
|
||||
constexpr int32_t allreduce_used_core = 16;
|
||||
int32_t one_comm_count = swizzl_count;
|
||||
int32_t loop_num_per_comm = one_comm_count * n_loop;
|
||||
int32_t comm_count = DivCeil(core_loop, loop_num_per_comm);
|
||||
int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT;
|
||||
|
||||
for (int cal_idx = 0; cal_idx < comm_count; ++cal_idx) {
|
||||
uint64_t flag_idx = cal_idx % pipe_depth;
|
||||
int32_t m_total = (cal_idx == comm_count - 1) ?
|
||||
m - cal_idx * swizzl_count * m0 : swizzl_count * m0;
|
||||
int32_t m_per_rank = DivCeil(m_total, rank_size);
|
||||
int32_t loop_offset = cal_idx * swizzl_count * m0;
|
||||
|
||||
WaitEvent(flag_idx);
|
||||
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
|
||||
CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
|
||||
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
|
||||
|
||||
if (aiv_idx == 0 && core_idx < allreduce_used_core) {
|
||||
int32_t m_cur_rank = LimitRange(m_total - rank * m_per_rank, 0, m_per_rank);
|
||||
int32_t m_per_core = DivCeil(m_cur_rank, allreduce_used_core);
|
||||
int32_t m_cur_core = LimitRange(m_cur_rank - core_idx * m_per_core, 0, m_per_core);
|
||||
int32_t core_offset_m = loop_offset + rank * m_per_rank + core_idx * m_per_core;
|
||||
ParallelWithSplitStepOneAddNorm(core_offset_m * n, m_cur_core);
|
||||
}
|
||||
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
|
||||
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
|
||||
CrossRankSyncV1(FLAG_ADD_IDX, cal_idx + 1);
|
||||
SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
|
||||
|
||||
{ // ParallelWithSplitStepTwo
|
||||
int32_t used_core_per_rank = allreduce_used_core / rank_size;
|
||||
int32_t sub_core_idx = core_idx % used_core_per_rank;
|
||||
int32_t gather_rank_id = core_idx / used_core_per_rank;
|
||||
int32_t m_in_rank = LimitRange(m_total - gather_rank_id * m_per_rank, 0, m_per_rank);
|
||||
int32_t m_per_core = DivCeil(m_in_rank, used_core_per_rank);
|
||||
int32_t m_cur_core = LimitRange(m_in_rank - sub_core_idx * m_per_core, 0, m_per_core);
|
||||
int32_t core_offset_m = loop_offset + gather_rank_id * m_per_rank + sub_core_idx * m_per_core;
|
||||
auto gm_share_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(gather_rank_id);
|
||||
|
||||
bool filter_core_cond = aiv_idx == 0 && core_idx < allreduce_used_core && m_cur_core > 0;
|
||||
if (filter_core_cond) {
|
||||
ParallelAllGather(gm_out, gm_share_buff, core_offset_m * n, m_cur_core * n);
|
||||
}
|
||||
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
CrossRankSyncV2(FLAG_TWO_IDX, cal_idx + 1);
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
|
||||
if (is_gather_add_out) {
|
||||
if (filter_core_cond && gather_rank_id == rank) {
|
||||
ParallelAllGather(gm_share_buff, gm_add_output, core_offset_m * n, m_cur_core * n);
|
||||
}
|
||||
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP1, cal_idx + 1);
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
|
||||
if (filter_core_cond && gather_rank_id != rank) {
|
||||
ParallelAllGather(gm_add_output, gm_share_buff, core_offset_m * n, m_cur_core * n);
|
||||
}
|
||||
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP2, cal_idx + 1);
|
||||
SetAndWaitAivSync(flag_idx);
|
||||
}
|
||||
}
|
||||
|
||||
if (cal_idx <= comm_count - pipe_depth) {
|
||||
SetAicSync(flag_idx);
|
||||
}
|
||||
}
|
||||
ResetIpcFlags(FLAG_NUM);
|
||||
if (aiv_idx == 0 && core_idx < rank_size) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(other_rank);
|
||||
CheckBuffFlag(ub_ctrl_flag, state_buff + FLAG_ZERO_IDX, 0);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ void SetBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
|
||||
{
|
||||
*ub_ctrl_flag = flag;
|
||||
SetFlag<HardEvent::S_MTE3>(EVENT_ID2);
|
||||
WaitFlag<HardEvent::S_MTE3>(EVENT_ID2);
|
||||
CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0);
|
||||
}
|
||||
|
||||
__aicore__ void SetBuffFlagByAdd(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
|
||||
{
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
*ub_ctrl_flag = flag;
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
SetAtomicAdd<int32_t>();
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0);
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
SetAtomicNone();
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
__aicore__ void CheckBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag)
|
||||
{
|
||||
SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||
WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1);
|
||||
while (true) {
|
||||
CopyGmToUbufAlignB16(ub_ctrl_flag, buff, 1, sizeof(int32_t), 0, 0);
|
||||
SetFlag<HardEvent::MTE2_S>(EVENT_ID3);
|
||||
WaitFlag<HardEvent::MTE2_S>(EVENT_ID3);
|
||||
if (*ub_ctrl_flag == flag) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ void SetAicSync(uint64_t flag_idx)
|
||||
{
|
||||
FFTSCrossCoreSync<PIPE_MTE3>(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx);
|
||||
}
|
||||
|
||||
__aicore__ void ResetIpcFlags(int32_t num_flags)
|
||||
{
|
||||
for (int32_t idx = 0; idx <= num_flags; ++idx) {
|
||||
if (core_idx == 0 && aiv_idx == 0) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
|
||||
SetBuffFlag(ub_ctrl_flag, state_buff + idx, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ void CrossRankSyncV1(int32_t flag_idx, int32_t flag_data)
|
||||
{
|
||||
if (aiv_idx == 0 && core_idx == rank) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
|
||||
SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE);
|
||||
} else if (aiv_idx == 0 && core_idx < rank_size) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx);
|
||||
CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * flag_data);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ void CrossRankSyncV2(int32_t flag_idx, int32_t flag_data)
|
||||
{
|
||||
if (aiv_idx == 0 && core_idx < rank_size) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx);
|
||||
SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE);
|
||||
}
|
||||
if (aiv_idx == 0 && core_idx == rank) {
|
||||
__gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank);
|
||||
CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * rank_size * flag_data);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ void SetAndWaitAivSync(uint64_t flag_idx, int32_t pipe_depth = 2)
|
||||
{
|
||||
FFTSCrossCoreSync<PIPE_MTE3>(0, flag_idx + pipe_depth);
|
||||
WaitEvent(flag_idx + pipe_depth);
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t GetGmU32(GM_ADDR gm_addr)
|
||||
{
|
||||
copy_gm_to_ubuf_align_b32(ub_ctrl_flag, gm_addr, 0, 1, sizeof(uint32_t), 0, 0, 0, 0);
|
||||
PipeSync<HardEvent::MTE2_S>();
|
||||
return *reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag);
|
||||
}
|
||||
|
||||
__aicore__ inline void SetGmU32(GM_ADDR gm_addr, uint32_t data)
|
||||
{
|
||||
*reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag) = data;
|
||||
PipeSync<HardEvent::S_MTE3>();
|
||||
copy_ubuf_to_gm_align_b32(gm_addr, ub_ctrl_flag, 0, 1, sizeof(uint32_t), 0, 0, 0, 0);
|
||||
}
|
||||
|
||||
__aicore__ inline void CrossRankSyncEx(uint32_t flag_idx)
|
||||
{
|
||||
AscendC::SyncAll<true>();
|
||||
__asm__ __volatile__("");
|
||||
if (aiv_idx == 0 && core_idx == 0) {
|
||||
auto flag_addr = (GM_ADDR)hccl_.GetWindowsOutAddr(0) + flag_idx * AscendC::ONE_BLK_SIZE;
|
||||
uint32_t old_flag_data = GetGmU32(flag_addr);
|
||||
__asm__ __volatile__("");
|
||||
SetAtomicAdd<int32_t>();
|
||||
SetGmU32(flag_addr, 1);
|
||||
PipeSync<HardEvent::MTE3_S>();
|
||||
SetAtomicNone();
|
||||
__asm__ __volatile__("");
|
||||
|
||||
uint32_t new_flag_data;
|
||||
do {
|
||||
new_flag_data = GetGmU32(flag_addr);
|
||||
__asm__ __volatile__("");
|
||||
} while (new_flag_data - old_flag_data < rank_size);
|
||||
__asm__ __volatile__("");
|
||||
SetAtomicAdd<int32_t>();
|
||||
SetGmU32(flag_addr, 1);
|
||||
PipeSync<HardEvent::MTE3_S>();
|
||||
SetAtomicNone();
|
||||
}
|
||||
__asm__ __volatile__("");
|
||||
AscendC::SyncAll<true>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T min(const T& a, const T& b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T max(const T& a, const T& b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T LimitRange(const T& val, const T& low, const T& high) {
|
||||
return min(max(val, low), high);
|
||||
}
|
||||
|
||||
template <AscendC::HardEvent EVENT>
|
||||
__aicore__ inline void PipeSync()
|
||||
{
|
||||
AscendC::TEventID event_id = static_cast<event_t>(GetTPipePtr()->FetchEventID(EVENT));
|
||||
AscendC::SetFlag<EVENT>(event_id);
|
||||
AscendC::WaitFlag<EVENT>(event_id);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInGamma()
|
||||
{
|
||||
GlobalTensor<MmadDtype> gamma_global;
|
||||
gamma_global.SetGlobalBuffer((__gm__ MmadDtype *)gm_gamma, n);
|
||||
DataCopy(gammaBuf.Get<MmadDtype>(), gamma_global, n);
|
||||
PipeSync<HardEvent::MTE2_V>();
|
||||
}
|
||||
|
||||
__aicore__ void ParallelWithSplitStepOneAddNorm(uint32_t core_buf_offset, uint32_t m_cur_core)
|
||||
{
|
||||
if (m_cur_core <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(rank);
|
||||
|
||||
GlobalTensor<MmadDtype> x_global;
|
||||
GlobalTensor<MmadDtype> y_global;
|
||||
GlobalTensor<MmadDtype> out_global;
|
||||
GlobalTensor<MmadDtype> add_out_global;
|
||||
|
||||
x_global.SetGlobalBuffer(buff + core_buf_offset);
|
||||
out_global.SetGlobalBuffer(buff + core_buf_offset);
|
||||
add_out_global.SetGlobalBuffer(gm_add_output + core_buf_offset);
|
||||
|
||||
uint32_t add_count = DivCeil(n, DIFUSION_ADD_LEN);
|
||||
|
||||
LocalTensor<MmadDtype> x_local;
|
||||
LocalTensor<MmadDtype> y_local;
|
||||
|
||||
for (uint32_t i = 0; i < m_cur_core; i++) {
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
|
||||
x_local = inQueueX.AllocTensor<MmadDtype>();
|
||||
for (uint32_t j = 0; j < add_count; j++) {
|
||||
uint32_t add_offset = j * DIFUSION_ADD_LEN;
|
||||
uint32_t add_len = min<uint32_t>(n - add_offset, DIFUSION_ADD_LEN);
|
||||
|
||||
DataCopy(x_local[add_offset], x_global[i * n + add_offset], add_len);
|
||||
inQueueX.EnQue(x_local);
|
||||
|
||||
uint32_t iterate_end = (rank + 1) % rank_size;
|
||||
y_local = inQueueY.AllocTensor<MmadDtype>();
|
||||
for (uint32_t k = 0; k < rank_size; ++k) {
|
||||
uint32_t iterate_idx = iterate_end + k;
|
||||
if (iterate_idx >= rank_size) {
|
||||
iterate_idx -= rank_size;
|
||||
}
|
||||
|
||||
if (iterate_idx == rank) {
|
||||
y_global.SetGlobalBuffer(gm_add_input + core_buf_offset);
|
||||
} else {
|
||||
auto other_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(iterate_idx);
|
||||
y_global.SetGlobalBuffer(other_buff + core_buf_offset);
|
||||
}
|
||||
DataCopy(y_local[k * add_len], y_global[i * n + add_offset], add_len);
|
||||
}
|
||||
inQueueY.EnQue(y_local);
|
||||
x_local = inQueueX.DeQue<MmadDtype>();
|
||||
y_local = inQueueY.DeQue<MmadDtype>();
|
||||
|
||||
Cast(x_fp32[add_offset], x_local[add_offset], RoundMode::CAST_NONE, add_len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
for (uint32_t k = 0; k < rank_size; ++k) {
|
||||
// use sqx as shared buf, required n >= add_len
|
||||
Cast(sqx, y_local[k * add_len], RoundMode::CAST_NONE, add_len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x_fp32[add_offset], x_fp32[add_offset], sqx, add_len);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
inQueueY.FreeTensor(y_local);
|
||||
}
|
||||
inQueueX.FreeTensor(x_local);
|
||||
|
||||
// copy add result out
|
||||
LocalTensor<MmadDtype> add_out = addOutQueue.AllocTensor<MmadDtype>();
|
||||
Cast(add_out, x_fp32, RoundMode::CAST_RINT, n);
|
||||
addOutQueue.EnQue(add_out);
|
||||
add_out = addOutQueue.DeQue<MmadDtype>();
|
||||
DataCopy(add_out_global[i * n], add_out, n);
|
||||
addOutQueue.FreeTensor(add_out);
|
||||
|
||||
LocalTensor<MmadDtype> gamma_local = gammaBuf.Get<MmadDtype>();
|
||||
LocalTensor<MmadDtype> out_local = outQueue.AllocTensor<MmadDtype>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
|
||||
// make sure precision is same in bf16 case
|
||||
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Cast(x_fp32, out_local, RoundMode::CAST_NONE, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Mul(sqx, x_fp32, x_fp32, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, (float)1.0 / n, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
ReduceSum(sqx, sqx, reduce_buf_local, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Adds(sqx, sqx, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(sqx, sqx, 1);
|
||||
Duplicate(reduce_buf_local, (float)1.0, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Div(sqx, reduce_buf_local, sqx, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
PipeSync<HardEvent::V_S>();
|
||||
float rstd_value = sqx.GetValue(0);
|
||||
PipeSync<HardEvent::S_V>();
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(x_fp32, x_fp32, rstd_value, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
if constexpr (std::is_same<MmadDtype, half>::value) {
|
||||
Cast(out_local, x_fp32, RoundMode::CAST_NONE, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(out_local, gamma_local, out_local, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else if constexpr (std::is_same<MmadDtype, bfloat16_t>::value) {
|
||||
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x_fp32, out_local, RoundMode::CAST_NONE, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(sqx, gamma_local, RoundMode::CAST_NONE, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Mul(x_fp32, x_fp32, sqx, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(out_local, x_fp32, RoundMode::CAST_RINT, n);
|
||||
PipeBarrier<PIPE_V>();
|
||||
PipeSync<HardEvent::V_MTE2>();
|
||||
}
|
||||
|
||||
outQueue.EnQue(out_local);
|
||||
out_local = outQueue.DeQue<MmadDtype>();
|
||||
DataCopy(out_global[i * n], out_local, n);
|
||||
outQueue.FreeTensor(out_local);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ void ParallelAllGather(__gm__ MmadDtype *gm_dst, __gm__ MmadDtype *gm_src,
|
||||
uint32_t core_buf_offset, uint32_t data_len)
|
||||
{
|
||||
GlobalTensor<MmadDtype> src_global;
|
||||
GlobalTensor<MmadDtype> dst_global;
|
||||
src_global.SetGlobalBuffer(gm_src);
|
||||
dst_global.SetGlobalBuffer(gm_dst);
|
||||
|
||||
constexpr uint32_t PIPELINE_COPY_NUM = sizeof(allgatherBuf) / sizeof(allgatherBuf[0]);
|
||||
TEventID ev_mte3_mte2[PIPELINE_COPY_NUM];
|
||||
TEventID ev_mte2_mte3[PIPELINE_COPY_NUM];
|
||||
LocalTensor<MmadDtype> local_tensors[PIPELINE_COPY_NUM];
|
||||
|
||||
for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) {
|
||||
ev_mte3_mte2[i] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_MTE2>();
|
||||
ev_mte2_mte3[i] = GetTPipePtr()->AllocEventID<HardEvent::MTE2_MTE3>();
|
||||
SetFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
|
||||
local_tensors[i] = allgatherBuf[i].Get<MmadDtype>();
|
||||
}
|
||||
|
||||
uint32_t offset = core_buf_offset;
|
||||
uint32_t copy_len = max_ub_ping_pong_size; // num of MmadDtype, not the byte length
|
||||
uint32_t copy_count = DivCeil(data_len, copy_len);
|
||||
uint32_t pipe_id = 0;
|
||||
|
||||
for (uint32_t i = 0; i < copy_count; i++) {
|
||||
uint32_t actual_copy_len =
|
||||
(i == copy_count - 1) ? (data_len - i * copy_len) : copy_len;
|
||||
|
||||
auto &local_tensor = local_tensors[pipe_id];
|
||||
|
||||
WaitFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[pipe_id]);
|
||||
DataCopy(local_tensor, src_global[offset], actual_copy_len);
|
||||
SetFlag<HardEvent::MTE2_MTE3>(ev_mte2_mte3[pipe_id]);
|
||||
WaitFlag<HardEvent::MTE2_MTE3>(ev_mte2_mte3[pipe_id]);
|
||||
DataCopy(dst_global[offset], local_tensor, actual_copy_len);
|
||||
SetFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[pipe_id]);
|
||||
|
||||
offset += actual_copy_len;
|
||||
pipe_id = (pipe_id + 1) % PIPELINE_COPY_NUM;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) {
|
||||
WaitFlag<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_MTE2>(ev_mte3_mte2[i]);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_MTE3>(ev_mte2_mte3[i]);
|
||||
}
|
||||
|
||||
PipeBarrier<PIPE_ALL>();
|
||||
}
|
||||
|
||||
__gm__ MmadDtype *gm_out;
|
||||
__gm__ MmadDtype *gm_add_input;
|
||||
__gm__ MmadDtype *gm_add_output;
|
||||
__gm__ MmadDtype *gm_gamma;
|
||||
__ubuf__ int32_t *ub_ctrl_flag;
|
||||
|
||||
int32_t batch_size;
|
||||
int32_t m;
|
||||
int32_t k;
|
||||
int32_t n;
|
||||
int32_t m0;
|
||||
int32_t k0;
|
||||
int32_t n0;
|
||||
|
||||
int32_t m_loop;
|
||||
int32_t n_loop;
|
||||
int32_t k_loop;
|
||||
int32_t core_loop;
|
||||
int32_t core_idx;
|
||||
|
||||
int32_t rank;
|
||||
int32_t rank_size;
|
||||
int32_t tiling_key;
|
||||
int32_t swizzl_count;
|
||||
bool swizzl_direct;
|
||||
|
||||
bool trans_a;
|
||||
bool trans_b;
|
||||
bool is_int8;
|
||||
bool is_91093;
|
||||
bool is_gather_add_out;
|
||||
|
||||
int32_t aiv_idx;
|
||||
int32_t other_rank;
|
||||
int32_t core_num;
|
||||
int32_t max_ub_single_dma_size;
|
||||
int32_t max_ub_ping_pong_size;
|
||||
|
||||
int32_t gm_c_pingpong_size;
|
||||
int32_t withSerialMode;
|
||||
int32_t tag;
|
||||
int32_t comm_npu_split;
|
||||
int32_t comm_data_split;
|
||||
int32_t comm_direct;
|
||||
|
||||
int32_t core_count;
|
||||
bool is_deterministic;
|
||||
|
||||
QuantGranularity dequant_granularity;
|
||||
int32_t dequant_group_size;
|
||||
QuantGranularity quant_granularity;
|
||||
int32_t quant_group_size;
|
||||
|
||||
WorkspaceInfo workspace_info;
|
||||
int32_t ag_dim;
|
||||
int32_t rs_dim;
|
||||
bool inner_dim_is_Ag;
|
||||
bool weight_nz{false};
|
||||
|
||||
float epsilon;
|
||||
|
||||
TPipe pipe;
|
||||
AscendC::TBufPool<TPosition::VECCALC, TBUF_POOL_MAX_BUFID_SIZE> step1BufPool;
|
||||
AscendC::TBufPool<TPosition::VECCALC, TBUF_POOL_MAX_BUFID_SIZE> step2BufPool;
|
||||
|
||||
AscendC::TQue<AscendC::QuePosition::VECIN, TQUE_DEPTH> inQueueX, inQueueY;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> outQueueZ;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> addOutQueue;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, TQUE_DEPTH> outQueue;
|
||||
|
||||
AscendC::TBuf<TPosition::VECCALC> ctrlBuf;
|
||||
AscendC::TBuf<TPosition::VECCALC> gammaBuf;
|
||||
AscendC::TBuf<TPosition::VECCALC> xFp32Buf;
|
||||
AscendC::TBuf<TPosition::VECCALC> sqxBuf;
|
||||
AscendC::TBuf<TPosition::VECCALC> reduceFp32Buf;
|
||||
AscendC::TBuf<TPosition::VECCALC> allgatherBuf[2];
|
||||
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,101 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H
|
||||
#define MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H
|
||||
|
||||
#include <cstdint>
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
enum QuantGranularity : int {
|
||||
QUANT_GRANULARITY_UNDEFINED = -1,
|
||||
PER_TENSOR = 0,
|
||||
PER_CHANNEL = 1,
|
||||
PER_GROUP = 2,
|
||||
QUANT_GRANULARITY_MAX = 3,
|
||||
};
|
||||
|
||||
struct Opshape {
|
||||
int32_t batchSize = 1;
|
||||
int32_t m = -1;
|
||||
int32_t k = -1;
|
||||
int32_t n = -1;
|
||||
};
|
||||
|
||||
struct PPTilingData {
|
||||
Opshape opShape = {};
|
||||
int32_t m0 = 1;
|
||||
int32_t k0 = 1;
|
||||
int32_t n0 = 1;
|
||||
int32_t mLoop = 1;
|
||||
int32_t kLoop = 1;
|
||||
int32_t nLoop = 1;
|
||||
int32_t coreLoop = 1;
|
||||
int32_t swizzlCount = 1;
|
||||
int32_t swizzlDirect = 0;
|
||||
uint32_t tilingKey = 0;
|
||||
int32_t blockDim = 1;
|
||||
int32_t splitK = 0;
|
||||
bool weightNz = false;
|
||||
bool isTransA = false;
|
||||
bool isTransB = false;
|
||||
bool isGatherAddOut = false;
|
||||
};
|
||||
|
||||
struct CommTilingData {
|
||||
int32_t rank = 1;
|
||||
int32_t rankSize = 1;
|
||||
int32_t pValue = 1;
|
||||
int32_t ubMoveNum = 1;
|
||||
int32_t write2OtherRank = 0;
|
||||
int32_t withSerialMode = 0;
|
||||
int32_t tag = 0;
|
||||
int32_t commNpuSplit = 1;
|
||||
int32_t commDataSplit = 1;
|
||||
int32_t commDirect = 0;
|
||||
int32_t lenPerLoop = 1;
|
||||
int32_t is91093 = 0;
|
||||
int32_t buffer_size = 0;
|
||||
};
|
||||
|
||||
struct RmsNormTilingData {
|
||||
RmsNormTiling tiling{};
|
||||
uint32_t loopCount;
|
||||
uint32_t calcBytes;
|
||||
float epsilon{};
|
||||
};
|
||||
|
||||
struct QuantInfo {
|
||||
QuantGranularity dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
int32_t dequantGroupSize = -1;
|
||||
QuantGranularity quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
|
||||
int32_t quantGroupSize = -1;
|
||||
};
|
||||
|
||||
struct MatmulAllreduceAddRmsnormInfo {
|
||||
PPTilingData ppTilingData{};
|
||||
CommTilingData commTilingData{};
|
||||
RmsNormTilingData rmsnormTilingData{};
|
||||
QuantInfo quantInfo{};
|
||||
};
|
||||
|
||||
struct MatmulAllreduceAddRmsnormTilingData {
|
||||
Mc2InitTiling mc2InitTiling;
|
||||
Mc2CcTiling mc2CcTiling;
|
||||
MatmulAllreduceAddRmsnormInfo matmulAllreduceAddRmsnormInfo;
|
||||
};
|
||||
|
||||
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H
|
||||
@@ -0,0 +1,414 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H
|
||||
#define MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H
|
||||
|
||||
#include <type_traits>
|
||||
#include "kernel_operator.h"
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr int64_t ND2NZ_STRIDE_LIMIT = 65536;
|
||||
constexpr int32_t AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID = 12;
|
||||
constexpr int32_t MAX_BLOCK_COUNT = 2;
|
||||
constexpr int32_t BLOCK_COUNT_3 = 3;
|
||||
constexpr int32_t BLOCK_COUNT_4 = 4;
|
||||
constexpr int32_t TILE_BLOCK_MOD = 2;
|
||||
|
||||
constexpr int32_t BLOCK_SIZE_32B = 32;
|
||||
constexpr int32_t BLOCK_SIZE_256B = 256;
|
||||
constexpr int32_t BLOCK_SIZE_512B = 512;
|
||||
|
||||
constexpr int32_t FFTS_SYNC_INTERNEL_MODE = 0;
|
||||
constexpr int32_t FFTS_SYNC_AICORE_GROUP_MODE = 2;
|
||||
|
||||
constexpr int32_t SWIZZL_MASK = 0b100000;
|
||||
constexpr int32_t TRANS_A_MASK = 0b010000;
|
||||
constexpr int32_t TRANS_B_MASK = 0b001000;
|
||||
constexpr int32_t INT8_MASK = 0b000100;
|
||||
constexpr int32_t BIAS_MASK = 0b000010;
|
||||
|
||||
template <typename T, size_t SIZE>
|
||||
struct BaseBlock {
|
||||
static_assert((SIZE & (SIZE - 1)) == 0, "Invalid block size");
|
||||
static constexpr size_t size = SIZE / sizeof(T);
|
||||
|
||||
static __aicore__ inline size_t Count(size_t len)
|
||||
{
|
||||
return (len + size - 1) / size;
|
||||
}
|
||||
|
||||
static __aicore__ inline bool IsAligned(size_t len)
|
||||
{
|
||||
return len % size == 0;
|
||||
}
|
||||
|
||||
static __aicore__ inline size_t AlignUp(size_t len)
|
||||
{
|
||||
return (len + size - 1) & ~(size - 1);
|
||||
}
|
||||
|
||||
static __aicore__ inline size_t AlignDown(size_t len)
|
||||
{
|
||||
return len & ~(size - 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using Block32B = BaseBlock<T, BLOCK_SIZE_32B>;
|
||||
|
||||
template <typename T>
|
||||
using Block256B = BaseBlock<T, BLOCK_SIZE_256B>;
|
||||
|
||||
template <typename T>
|
||||
using Block512B = BaseBlock<T, BLOCK_SIZE_512B>;
|
||||
|
||||
struct WorkspaceInfo {
|
||||
__gm__ uint8_t *gm_a_align{ nullptr };
|
||||
__gm__ uint8_t *gm_b_align{ nullptr };
|
||||
__gm__ uint8_t *gm_accum{ nullptr };
|
||||
__gm__ uint8_t *gm_dequant_param{ nullptr };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline LocalTensor<T> CreateLocalTensor(__ubuf__ T *addr)
|
||||
{
|
||||
LocalTensor<T> tensor;
|
||||
TBuffAddr taddr;
|
||||
taddr.bufferAddr = reinterpret_cast<uint64_t>(addr);
|
||||
tensor.SetAddr(taddr);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset)
|
||||
{
|
||||
LocalTensor<T> tensor;
|
||||
tensor.address_.bufferAddr = buffer_offset;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset, uint8_t logic_pos)
|
||||
{
|
||||
LocalTensor<T> tensor;
|
||||
tensor.address_.logicPos = logic_pos;
|
||||
tensor.address_.bufferAddr = buffer_offset;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct IntrinsicCopyGmToL1Nd2Nz {
|
||||
static __aicore__ inline void move(
|
||||
__cbuf__ T *dst, __gm__ T *src,
|
||||
uint8_t sid, uint16_t ndNum, uint16_t nValue, uint16_t dValue,
|
||||
uint16_t srcNdMatrixStride, uint16_t srcDValue, uint16_t dstNzC0Stride,
|
||||
uint16_t dstNzNStride, uint16_t dstNzMatrixStride) {
|
||||
Nd2NzParams nd2nzParams(
|
||||
ndNum, nValue, dValue,
|
||||
srcNdMatrixStride, srcDValue, dstNzC0Stride,
|
||||
dstNzNStride, dstNzMatrixStride
|
||||
);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
LocalTensor<T> dstTensor;
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
GlobalTensor<T> srcTensor;
|
||||
srcTensor.SetGlobalBuffer(src);
|
||||
DataCopy(dstTensor, srcTensor, nd2nzParams);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CopyGmToL1Nd2zN {
|
||||
static __aicore__ inline void move(
|
||||
__cbuf__ T *dst, __gm__ T *src,
|
||||
uint16_t nValue, uint16_t dValue, uint32_t srcDValue, uint16_t dstNzC0Stride) {
|
||||
constexpr int BLOCK_LEN = 32 / sizeof(T);
|
||||
if (srcDValue < ND2NZ_STRIDE_LIMIT) {
|
||||
IntrinsicCopyGmToL1Nd2Nz<T>::move(
|
||||
dst,
|
||||
src,
|
||||
0,
|
||||
1,
|
||||
nValue,
|
||||
dValue,
|
||||
0,
|
||||
srcDValue,
|
||||
dstNzC0Stride,
|
||||
1,
|
||||
0
|
||||
);
|
||||
} else {
|
||||
for (int i = 0; i < nValue; i++) {
|
||||
IntrinsicCopyGmToL1Nd2Nz<T>::move(
|
||||
dst + i * BLOCK_LEN,
|
||||
src + i * srcDValue,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
dValue,
|
||||
0,
|
||||
0,
|
||||
dstNzC0Stride,
|
||||
0,
|
||||
0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
__aicore__ inline void AlignJudge(bool trans_a, bool trans_b, int32_t m, int32_t k, int32_t n, int32_t m_align,
|
||||
int32_t k_align, int32_t n_align, int32_t &aligned_a, int32_t &aligned_b)
|
||||
{
|
||||
if (!trans_a) {
|
||||
aligned_a = k != k_align;
|
||||
} else {
|
||||
aligned_a = (m != m_align && m != 1);
|
||||
}
|
||||
|
||||
if (!trans_b) {
|
||||
aligned_b = (n != n_align);
|
||||
} else {
|
||||
aligned_b = (k != k_align);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline WorkspaceInfo GetWorkspaceInfo(__gm__ uint8_t *gm_workspace, int32_t batch_size, int32_t m,
|
||||
int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool trans_a, bool trans_b,
|
||||
int32_t mmad_dsize, bool has_a_align, bool has_b_align, bool has_accum = false, bool has_dequant_param = false)
|
||||
{
|
||||
WorkspaceInfo workspace_info;
|
||||
uint64_t workspace_offset = 0;
|
||||
|
||||
if (has_a_align) {
|
||||
workspace_info.gm_a_align = gm_workspace + workspace_offset;
|
||||
workspace_offset += static_cast<uint64_t>(batch_size) * (trans_a ? k * m_align : m * k_align) * mmad_dsize;
|
||||
}
|
||||
|
||||
if (has_b_align) {
|
||||
workspace_info.gm_b_align = gm_workspace + workspace_offset;
|
||||
workspace_offset += static_cast<uint64_t>(batch_size) * (trans_b ? n * k_align : k * n_align) * mmad_dsize;
|
||||
}
|
||||
|
||||
if (has_accum) {
|
||||
workspace_info.gm_accum = gm_workspace + workspace_offset;
|
||||
workspace_offset += static_cast<uint64_t>(batch_size) * m * n * sizeof(int32_t);
|
||||
}
|
||||
|
||||
if (has_dequant_param) {
|
||||
workspace_info.gm_dequant_param = gm_workspace + workspace_offset;
|
||||
workspace_offset += n * sizeof(float32_t);
|
||||
}
|
||||
|
||||
return workspace_info;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void CopyCubfToBt(uint64_t dst, __cbuf__ T *src, uint16_t convControl, uint16_t nBurst,
|
||||
uint16_t lenBurst, uint16_t sourceGap, uint16_t dstGap)
|
||||
{
|
||||
DataCopyParams intriParams(nBurst, lenBurst, sourceGap, dstGap);
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::C2);
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
DataCopy(dstTensor, srcTensor, intriParams);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void CopyGmToCbuf(__cbuf__ T *dst, __gm__ T *src, uint8_t sid, uint16_t nBurst,
|
||||
uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride, pad_t padMode)
|
||||
{
|
||||
DataCopyParams intriParams(nBurst, lenBurst, srcStride, dstStride);
|
||||
GlobalTensor<T> srcTensor;
|
||||
srcTensor.SetGlobalBuffer(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
LocalTensor<T> dstTensor;
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, logicpos);
|
||||
DataCopy(dstTensor, srcTensor, intriParams);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void SetFpc(__fbuf__ T *src)
|
||||
{
|
||||
LocalTensor<T> tensor;
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
tensor = CreateLocalTensor<T>(src_buffer_offset);
|
||||
SetFixPipeConfig(tensor);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void LoadCbufToCaTranspose(__ca__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat,
|
||||
uint16_t srcStride, uint16_t dstStride, bool addrmode,
|
||||
uint16_t dstFracStride)
|
||||
{
|
||||
LoadData2dTransposeParams params(
|
||||
indexID,
|
||||
repeat,
|
||||
srcStride,
|
||||
dstStride,
|
||||
dstFracStride,
|
||||
addrmode
|
||||
);
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::A2);
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
LoadDataWithTranspose(dstTensor, srcTensor, params);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void LoadCbufToCbTranspose(__cb__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat,
|
||||
uint16_t srcStride, uint16_t dstStride, bool addrmode,
|
||||
uint16_t dstFracStride)
|
||||
{
|
||||
LoadData2dTransposeParams params(
|
||||
indexID,
|
||||
repeat,
|
||||
srcStride,
|
||||
dstStride,
|
||||
dstFracStride,
|
||||
addrmode
|
||||
);
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::B2);
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
LoadDataWithTranspose(dstTensor, srcTensor, params);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void LoadCbufToCa(__ca__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat,
|
||||
uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose,
|
||||
uint8_t addr_cal_mode)
|
||||
{
|
||||
LoadData2dParams params(
|
||||
baseIdx,
|
||||
repeat,
|
||||
srcStride,
|
||||
sid,
|
||||
dstStride,
|
||||
transpose,
|
||||
addr_cal_mode
|
||||
);
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::A2);
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
LoadData(dstTensor, srcTensor, params);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
__aicore__ inline void LoadCbufToCb(__cb__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat,
|
||||
uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose,
|
||||
uint8_t addr_cal_mode)
|
||||
{
|
||||
LoadData2dParams params(
|
||||
baseIdx,
|
||||
repeat,
|
||||
srcStride,
|
||||
sid,
|
||||
dstStride,
|
||||
transpose,
|
||||
addr_cal_mode
|
||||
);
|
||||
uint32_t src_buffer_offset = reinterpret_cast<uint64_t>(src);
|
||||
uint32_t dst_buffer_offset = reinterpret_cast<uint64_t>(dst);
|
||||
uint8_t src_logicpos = static_cast<uint8_t>(TPosition::C1);
|
||||
uint8_t dst_logicpos = static_cast<uint8_t>(TPosition::B2);
|
||||
LocalTensor<T> srcTensor;
|
||||
LocalTensor<T> dstTensor;
|
||||
srcTensor = CreateLocalTensor<T>(src_buffer_offset, src_logicpos);
|
||||
dstTensor = CreateLocalTensor<T>(dst_buffer_offset, dst_logicpos);
|
||||
LoadData(dstTensor, srcTensor, params);
|
||||
}
|
||||
|
||||
|
||||
__aicore__ inline void GetBlockIdx(int32_t loop_idx, int32_t m_loop, int32_t n_loop, int32_t swizzl_direction,
|
||||
int32_t swizzl_count, int64_t &m_idx, int64_t &n_idx)
|
||||
{
|
||||
uint32_t in_batch_idx = loop_idx % (m_loop * n_loop);
|
||||
if (swizzl_direction == 0) {
|
||||
uint32_t tile_block_loop = (m_loop + swizzl_count - 1) / swizzl_count;
|
||||
uint32_t tile_block_idx = in_batch_idx / (swizzl_count * n_loop);
|
||||
uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * n_loop);
|
||||
|
||||
uint32_t n_row = swizzl_count;
|
||||
if (tile_block_idx == tile_block_loop - 1) {
|
||||
n_row = m_loop - swizzl_count * tile_block_idx;
|
||||
}
|
||||
m_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_row;
|
||||
n_idx = in_tile_block_idx / n_row;
|
||||
if (tile_block_idx % TILE_BLOCK_MOD != 0) {
|
||||
n_idx = n_loop - n_idx - 1;
|
||||
}
|
||||
} else if (swizzl_direction == 1) {
|
||||
uint32_t tile_block_loop = (n_loop + swizzl_count - 1) / swizzl_count;
|
||||
uint32_t tile_block_idx = in_batch_idx / (swizzl_count * m_loop);
|
||||
uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * m_loop);
|
||||
|
||||
uint32_t n_col = swizzl_count;
|
||||
if (tile_block_idx == tile_block_loop - 1) {
|
||||
n_col = n_loop - swizzl_count * tile_block_idx;
|
||||
}
|
||||
m_idx = in_tile_block_idx / n_col;
|
||||
n_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_col;
|
||||
if (tile_block_idx % TILE_BLOCK_MOD != 0) {
|
||||
m_idx = m_loop - m_idx - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <pipe_t pipe>
|
||||
__aicore__ inline void FFTSCrossCoreSync(uint64_t mode, uint64_t flag_id)
|
||||
{
|
||||
uint64_t config = 1 | (mode << 4) | (flag_id << 8);
|
||||
ffts_cross_core_sync(pipe, config);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__aicore__ GlobalTensor<T> CreateGlobalTensor(__gm__ T *addr)
|
||||
{
|
||||
GlobalTensor<T> tensor;
|
||||
tensor.SetGlobalBuffer(addr);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_H
|
||||
@@ -807,6 +807,36 @@ at::Tensor npu_sparse_flash_attention(
|
||||
output);
|
||||
return output;
|
||||
}
|
||||
std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
|
||||
const at::Tensor &x1,
|
||||
const at::Tensor &x2,
|
||||
const at::Tensor &residual,
|
||||
const at::Tensor &gamma,
|
||||
c10::string_view group_tp,
|
||||
int64_t tp_rank_size,
|
||||
int64_t tp_rank_id,
|
||||
double epsilon,
|
||||
bool is_trans_b,
|
||||
bool is_gather_add_out)
|
||||
{
|
||||
at::Tensor output = at::empty_like(residual);
|
||||
at::Tensor add_out = at::empty_like(residual);
|
||||
|
||||
std::string group_tp_str(group_tp);
|
||||
|
||||
char *group_tp_ptr = group_tp_str.data();
|
||||
|
||||
float epsilon_f = static_cast<float>(epsilon);
|
||||
EXEC_NPU_CMD(aclnnMatmulAllreduceAddRmsnorm,
|
||||
// input
|
||||
x1, x2, residual, gamma,
|
||||
// attr
|
||||
group_tp_ptr, tp_rank_size, tp_rank_id, epsilon_f, is_trans_b, is_gather_add_out,
|
||||
// output
|
||||
output, add_out);
|
||||
|
||||
return {output, add_out};
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -921,4 +951,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" int max_output_size, Tensor! out) -> Tensor"
|
||||
);
|
||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||
|
||||
ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
|
||||
str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)");
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm);
|
||||
}
|
||||
|
||||
@@ -264,6 +264,23 @@ at::Tensor npu_sparse_flash_attention_meta(
|
||||
at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
|
||||
return output;
|
||||
}
|
||||
std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm_meta(
|
||||
const at::Tensor &x1,
|
||||
const at::Tensor &x2,
|
||||
const at::Tensor &residual,
|
||||
const at::Tensor &gamma,
|
||||
c10::string_view group_tp,
|
||||
int64_t tp_rank_size,
|
||||
int64_t tp_rank_id,
|
||||
double epsilon,
|
||||
bool is_trans_b,
|
||||
bool is_gather_add_out)
|
||||
{
|
||||
at::Tensor output = at::empty_like(residual);
|
||||
at::Tensor add_out = at::empty_like(residual);
|
||||
|
||||
return {output, add_out};
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
@@ -296,5 +313,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
|
||||
// MoE dispatch-ffn-combine
|
||||
ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta);
|
||||
// matmul allreduce add rmsnorm
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
|
||||
}
|
||||
}
|
||||
|
||||
135
tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py
Normal file
135
tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import gc
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
import torchair
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.mode = "reduce-overhead"
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
torch_npu.npu.config.allow_internal_format = True
|
||||
enable_custom_op()
|
||||
|
||||
global_rank_id = 0
|
||||
|
||||
|
||||
def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon):
|
||||
c_ret = torch.nn.functional.linear(a, b)
|
||||
dist.all_reduce(c_ret)
|
||||
rmsnorm_ret, _, add_ret = torch_npu.npu_add_rms_norm(
|
||||
c_ret, residual, gamma, epsilon)
|
||||
return rmsnorm_ret, add_ret
|
||||
|
||||
|
||||
def worker(rank, ep_world_size, batch_size, m, k, n):
|
||||
global global_rank_id
|
||||
global_rank_id = rank
|
||||
rank = rank
|
||||
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "29500"
|
||||
dist.init_process_group(backend="hccl",
|
||||
rank=rank,
|
||||
world_size=ep_world_size)
|
||||
|
||||
ep_ranks_list = list(np.arange(0, ep_world_size))
|
||||
|
||||
ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
|
||||
|
||||
torch_npu.npu.set_device(rank)
|
||||
ep_hcomm_info = ep_group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(rank)
|
||||
|
||||
torch_npu.npu.synchronize(rank)
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x1, x2, residual, gamma, ep_hcomm_info, epsilon,
|
||||
is_trans_b, is_allgather_add_out):
|
||||
out1, add_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
|
||||
x1, x2, residual, gamma, ep_hcomm_info, ep_world_size,
|
||||
global_rank_id, epsilon, is_trans_b, is_allgather_add_out)
|
||||
return out1, add_out1
|
||||
|
||||
DTYPE = torch.bfloat16
|
||||
USE_ONES = False
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
if USE_ONES:
|
||||
x1 = torch.ones([m, k], dtype=DTYPE).npu(rank)
|
||||
x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
|
||||
else:
|
||||
x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
|
||||
x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)
|
||||
|
||||
if USE_ONES:
|
||||
residual = torch.full([m, n], 2048, dtype=DTYPE).npu(rank)
|
||||
else:
|
||||
residual = torch.full([m, n], 0, dtype=DTYPE).npu(rank)
|
||||
|
||||
gamma = torch.full([n], 1, dtype=DTYPE).npu(rank)
|
||||
|
||||
epsilon = 1e-5
|
||||
is_trans_b = True
|
||||
is_allgather_add_out = True
|
||||
warnup_cnt = 5
|
||||
repeat_cnt = 10
|
||||
|
||||
def run_golden_case(loop_cnt):
|
||||
for _ in range(loop_cnt):
|
||||
golden_out, golden_add_out = golden_op_matmul_allreduce_add_rmsnorm(
|
||||
x1, x2, residual, gamma, epsilon)
|
||||
torch_npu.npu.synchronize(rank)
|
||||
return golden_out, golden_add_out
|
||||
|
||||
run_golden_case(warnup_cnt)
|
||||
|
||||
golden_out, golden_add_out = run_golden_case(repeat_cnt)
|
||||
golden_out = golden_out.detach().cpu()
|
||||
golden_add_out = golden_add_out.detach().cpu()
|
||||
|
||||
mod = Module().npu()
|
||||
opt_model = torch.compile(mod, backend=npu_backend)
|
||||
|
||||
def run_custom_case(loop_cnt):
|
||||
for _ in range(loop_cnt):
|
||||
out, add_out = opt_model(x1, x2, residual, gamma, ep_hcomm_info,
|
||||
epsilon, is_trans_b, is_allgather_add_out)
|
||||
torch_npu.npu.synchronize(rank)
|
||||
return out, add_out
|
||||
|
||||
# warn up
|
||||
run_custom_case(warnup_cnt)
|
||||
|
||||
out, add_out = run_custom_case(repeat_cnt)
|
||||
out = out.detach().cpu()
|
||||
add_out = add_out.detach().cpu()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
torch.testing.assert_close(golden_out, out, atol=0.1, rtol=0.005)
|
||||
torch.testing.assert_close(golden_add_out, add_out, atol=0.1, rtol=0.005)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_matmul_allreduce_add_rmsnorm_kernel():
|
||||
ep_world_size = 8
|
||||
batch_size = 1
|
||||
m = 10000
|
||||
k = 1024
|
||||
n = 5120
|
||||
args = (ep_world_size, batch_size, m, k, n)
|
||||
mp.spawn(worker, args=args, nprocs=ep_world_size, join=True)
|
||||
Reference in New Issue
Block a user