From ba9cda9dfd83765a7e79a074c033818ff5a76bd6 Mon Sep 17 00:00:00 2001
From: Trunrain <270250579@qq.com>
Date: Wed, 10 Dec 2025 09:05:33 +0800
Subject: [PATCH] [Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)

What this PR does / why we need it?
Adds a fused custom operator (Matmul + AllReduce + Add + RMSNorm) to optimize Qwen3 32B.

Does this PR introduce _any_ user-facing change?
No

How was this patch tested?
vLLM version: v0.11.2
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: tongrunze
Co-authored-by: tongrunze
---
 csrc/build_aclnn.sh                           |  15 +-
 .../op_host/CMakeLists.txt                    |  51 ++
 .../aclnn_matmul_allreduce_add_rmsnorm.cpp    |  89 +++
 .../aclnn_matmul_allreduce_add_rmsnorm.h      |  52 ++
 .../matmul_allreduce_add_rmsnorm_def.cpp      |  68 ++
 .../matmul_allreduce_add_rmsnorm_proto.cpp    |  68 ++
 .../matmul_allreduce_add_rmsnorm_tiling.cpp   | 619 +++++++++++++++
 .../matmul_allreduce_add_rmsnorm_workspace.h  |  79 ++
 .../matmul_allreduce_add_rmsnorm.cpp          |  50 ++
 .../matmul_allreduce_add_rmsnorm_aic_kernel.h | 359 +++++++++
 .../matmul_allreduce_add_rmsnorm_aiv_kernel.h | 702 ++++++++++++++++++
 .../matmul_allreduce_add_rmsnorm_tiling.h     | 101 +++
 .../matmul_allreduce_add_rmsnorm_utils.h      | 414 +++++++++++
 csrc/torch_binding.cpp                        |  34 +
 csrc/torch_binding_meta.cpp                   |  19 +
 .../ops/test_matmul_allreduce_add_rmsnorm.py  | 135 ++++
 16 files changed, 2854 insertions(+), 1 deletion(-)
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.cpp
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.h
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_def.cpp
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_proto.cpp
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_tiling.cpp
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_workspace.h
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aic_kernel.h
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aiv_kernel.h
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_tiling.h
 create mode 100644 csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_utils.h
 create mode 100644 tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py

diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh
index 758856b7..709a9813 100644
--- a/csrc/build_aclnn.sh
+++ b/csrc/build_aclnn.sh
@@ -11,7 +11,20 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
     exit 0
 elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
     # ASCEND910B (A2) series
-    CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
+    # dependency: catlass
+    git config --global --add safe.directory "$ROOT_DIR"
+    CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass/include
+    if [[ ! -d "${CATLASS_PATH}" ]]; then
+        echo "dependency catlass is missing, trying to fetch it..."
+        if !
git submodule update --init --recursive; then + echo "fetch failed" + exit 1 + fi + fi + ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) + export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} + + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt b/csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt new file mode 100644 index 00000000..1dd7f8e0 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME MatmulAllreduceAddRmsnormTensorList + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnInner PRIVATE + matmul_allreduce_add_rmsnorm_def.cpp +) + +target_sources(opapi PRIVATE + aclnn_matmul_allreduce_add_rmsnorm.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + aclnn_matmul_allreduce_add_rmsnorm.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + aclnn_matmul_allreduce_add_rmsnorm.cpp + ) +endif () + +target_sources(optiling PRIVATE + matmul_allreduce_add_rmsnorm_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + matmul_allreduce_add_rmsnorm_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_matmul_allreduce_add_rmsnorm.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.cpp new file mode 100644 index 00000000..ec71fa91 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "graph/types.h" +#include "aclnn/opdev/platform.h" +#include "aclnn_matmul_allreduce_add_rmsnorm.h" + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnormGetWorkspaceSize( + const aclTensor *x1, + const aclTensor *x2, + const aclTensor *residual, + const aclTensor *gamma, + char *groupTp, + int64_t tpRankSize, + int64_t tpRankId, + double epsilon, + bool isTransB, + bool isGatherAddOut, + const aclTensor *yOut, + const aclTensor *addOutOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnorm( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +#ifdef __cplusplus +extern "C" { +#endif + +aclnnStatus aclnnMatmulAllreduceAddRmsnormGetWorkspaceSize( + const aclTensor *x1, + const aclTensor *x2, + const aclTensor *residual, + const aclTensor *gamma, + char *groupTp, + int64_t tpRankSize, + int64_t tpRankId, + double epsilon, + bool isTransB, + bool isGatherAddOut, + const aclTensor *y, + const aclTensor *addOut, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + return aclnnInnerMatmulAllreduceAddRmsnormGetWorkspaceSize(x1, x2, residual, + gamma, groupTp, tpRankSize, tpRankId, epsilon, isTransB, isGatherAddOut, y, addOut, workspaceSize, executor); +} + +aclnnStatus aclnnMatmulAllreduceAddRmsnorm( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + return aclnnInnerMatmulAllreduceAddRmsnorm(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.h b/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.h new file mode 100644 index 00000000..b2920a24 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ACLNN_MATMUL_ALLREDUCE_ADD_RMSNORM +#define ACLNN_MATMUL_ALLREDUCE_ADD_RMSNORM + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnMatmulAllreduceAddRmsnormGetWorkspaceSize( + const aclTensor *x1, + const aclTensor *x2, + const aclTensor *residual, + const aclTensor *gamma, + char *groupTp, + int64_t tpRankSize, + int64_t tpRankId, + double epsilon, + bool isTransB, + bool isGatherAddOut, + const aclTensor *y, + const aclTensor *addOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnMatmulAllreduceAddRmsnorm( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_def.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_def.cpp new file mode 100644 index 00000000..9e44dee2 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_def.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "register/op_def_registry.h" + +namespace ops{ +class MatmulAllreduceAddRmsnorm : public OpDef { +public: + explicit MatmulAllreduceAddRmsnorm(const char* name) : OpDef(name) + { + this->Input("x1") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("x2") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("residual") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("gamma") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("add_out") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Attr("group_tp").String(); + this->Attr("tp_rank_size").Int(); + this->Attr("tp_rank_id").Int(); + this->Attr("epsilon").AttrType(OPTIONAL).Float(1e-6); + this->Attr("is_trans_b").AttrType(OPTIONAL).Bool(false); + this->Attr("is_gather_add_out").AttrType(OPTIONAL).Bool(false); + + this->MC2().HcclGroup({"group_tp"}); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(MatmulAllreduceAddRmsnorm); +} \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_proto.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_proto.cpp new file mode 100644 index 00000000..027ff88f --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_proto.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" + +namespace ge { + constexpr uint32_t RESIDUAL_INDEX = 3; + constexpr uint32_t OUTPUT_Y_INDEX = 0; + constexpr uint32_t OUTPUT_ADD_OUT_INDEX = 1; + constexpr int SHAPE_INDEX0 = 0; + constexpr int SHAPE_INDEX1 = 1; + constexpr int SHAPE_INDEX2 = 2; + constexpr int DIM_NUM_2 = 2; + constexpr int DIM_NUM_3 = 3; + +static void CloneShape(const gert::Shape* src, gert::Shape* dst) +{ + int ndim = src->GetDimNum(); + dst->SetDimNum(ndim); + for (int i = 0; i < ndim; ++i) { + dst->SetDim(i, src->GetDim(i)); + } +} + +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + const gert::Shape* residualShape = context->GetInputShape(RESIDUAL_INDEX); + int residualDimNum = residualShape->GetDimNum(); + + if (residualDimNum != DIM_NUM_2 && residualDimNum != DIM_NUM_3) { + return GRAPH_FAILED; + } + + gert::Shape* x1OutShape = context->GetOutputShape(OUTPUT_Y_INDEX); + gert::Shape* addOutShape = context->GetOutputShape(OUTPUT_ADD_OUT_INDEX); + CloneShape(residualShape, x1OutShape); + CloneShape(residualShape, addOutShape); + + return GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) +{ + const auto residualDataType = context->GetInputDataType(RESIDUAL_INDEX); + context->SetOutputDataType(OUTPUT_Y_INDEX, residualDataType); + context->SetOutputDataType(OUTPUT_ADD_OUT_INDEX, residualDataType); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP(MatmulAllreduceAddRmsnorm) + .InferShape(InferShape) + .InferDataType(InferDataType); +} \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_tiling.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_tiling.cpp new file mode 100644 index 00000000..6a92e402 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_tiling.cpp @@ -0,0 +1,619 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "log/ops_log.h" +#include "error/ops_error.h" + +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "../op_kernel/matmul_allreduce_add_rmsnorm_tiling.h" +#include "matmul_allreduce_add_rmsnorm_workspace.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" + +typedef enum { + ATTR_TP_INDEX = 0, + ATTR_RANK_SIZE_INDEX, + ATTR_RANK_ID_INDEX, + ATTR_EPSILON_INDEX, + ATTR_IS_TRANS_B_INDEX, + ATTR_IS_GATHER_ADD_OUT_INDEX +} ATTR_TYPE; + +int32_t CeilDev(int32_t num, int32_t div) +{ + if (div == 0) { + return 0; + } + return (num + div - 1) / div; +} +static constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8; +static constexpr uint32_t BATCH_SIZE_ONE = 1; +static constexpr uint32_t DEFAULT_ROW = 128; +static constexpr uint32_t DEFAULT_COL = 256; +static constexpr uint32_t DEFAULT_SWIZZLE_COUNT = 4; +static constexpr int32_t VALID_UB_MOVE_NUM = 20480; +static constexpr int32_t COMMDATASPLIT_ONE = 1; +static constexpr int32_t COMM_DATA_DIRECT = 0; +static constexpr uint32_t ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT = 128; +static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_DATASPLIT_DEFAULT = 16; +static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 100; +static constexpr int32_t HALF_KBYTE = 512; +static constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT = 14; +static constexpr int32_t SWIZZLE_DIRECT_ONE = 1; +static constexpr int32_t COMMNPUSPLIT_ONE = 1; +static constexpr int32_t COMMDATASPLIT_SIXTEEN = 16; +constexpr int32_t SECOND_TO_MS = 1000; +constexpr int64_t MATMUL_BASE_100US = static_cast(1024) * 8192 * 1024; +constexpr int64_t ALLREDUCE_BASE_100US = 4096 * 1024; +constexpr double ONE_K = 1024.0; +constexpr double B1_FLOP_PER_MS = (364 * 0.8) * 1e9; +constexpr double DOUBLE = 2.0; +constexpr double HALF_PROB = 0.5; +constexpr int32_t CONDITION_M_ST = 0; +constexpr int32_t CONDITION_M_END = 1; +constexpr int32_t CONDITION_K_ST = 2; +constexpr int32_t CONDITION_K_END = 3; +constexpr int32_t CONDITION_N_ST = 4; +constexpr int32_t CONDITION_N_END = 5; +constexpr int32_t RANKSIZE_FOUR = 4; +constexpr int32_t RANKSIZE_EIGHT = 8; +constexpr int32_t DIV_TWO = 2; +constexpr int32_t LENPERLOOP_DEFAULT = 5120; +constexpr int32_t MIN_UB_MOVE_NUM = 5120; +constexpr int32_t MAX_UB_NUM = 97280; // 190 * 1024 / 2 +constexpr int32_t MAX_P_VALUE = 15; + +constexpr int32_t DIM_NUM_TWO = 2; +constexpr int32_t DIM_NUM_THREE = 3; +constexpr int32_t DIM_INDEX_ZERO = 0; +constexpr int32_t DIM_INDEX_ONE = 1; +constexpr int32_t DIM_INDEX_TWO = 2; + +static constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; + +static constexpr uint32_t USE_CORE_NUM = 20; + +static std::vector ALLREDUCE_UBMOVENUM_COEF = {{-1.72352427e+01, + 2.56887672e-03, + -8.21819480e+00, + 8.70965589e+01, + -3.63853858e-01, + 1.27789264e+01, + 1.29782183e+02, + 1.90250023e-02, + -3.48175441e+00, + 6.18921914e+03, + 3.77072171e+03, + -5.86895290e+01, + -8.70740991e-01, + -1.40262280e-04, + -2.81910331e-08, + 3.22795486e-05, + -4.84522320e-03, + 2.94839177e-01, + 2.97260958e-03, + 9.08844709e+01, + -5.80426209e-10, + 38.183465184603484}}; + +static std::map>> ALLREDUCE_EIGHT_RANK_FP16_M0_MAP = { + {128, + {{-1, 31220, -1, 2147483647, -1, 768}, + {31220, 36980, 1280, 2147483647, -1, 768}, + {36980, 2147483647, -1, 2147483647, -1, 768}, + {-1, 2147483647, -1, 2147483647, 768, 2147483647}}}, + {256, {{31220, 36980, -1, 1280, -1, 768}}}}; + +static std::map>> 
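+    // Empirical lookup: each key is a ubMoveNum value (later scaled by HALF_KBYTE) and maps to the
+    // half-open (m, k, n] ranges {mStart, mEnd, kStart, kEnd, nStart, nEnd} it applies to; resolved by GetValueFromMKNConditionMap.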
ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_MAP = { + {100, + {{-1, 3072, -1, 2147483647, -1, 768}, + {3072, 19680, -1, 3072, -1, 768}, + {-1, 3072, -1, 2147483647, 768, 1536}, + {3072, 19680, -1, 3072, 768, 1536}, + {-1, 2147483647, 1792, 2976, 1536, 13312}}}, + {30, + {{3072, 19680, 3072, 2147483647, -1, 768}, + {19680, 2147483647, -1, 3072, -1, 1536}, + {-1, 2147483647, -1, 1792, 1536, 13312}, + {-1, 768, 2976, 2147483647, 5376, 13312}, + {-1, 768, -1, 2147483647, 13312, 2147483647}, + {26880, 2147483647, -1, 3072, 13312, 2147483647}}}, + {20, + {{3072, 19680, 3072, 2147483647, 768, 1536}, + {19680, 2147483647, 3072, 2147483647, -1, 1536}, + {-1, 2147483647, 2976, 2147483647, 1536, 5376}, + {768, 2147483647, 2976, 2147483647, 5376, 13312}, + {768, 26880, -1, 2147483647, 13312, 2147483647}, + {26880, 2147483647, 3072, 2147483647, 13312, 2147483647}}}}; + +static std::vector ALLREDUCE_PVALUE_COEF = {{-4.23166350e+00, + 6.71137487e-04, + -1.33434156e+00, + 1.12915884e+01, + -7.85892737e-02, + 2.59059897e+00, + 3.22129881e+01, + -5.15776887e-02, + 9.15542742e-01, + 1.56322201e+03, + 3.61977421e+01, + -5.49544589e-01, + -2.66903417e-01, + -3.68521920e-05, + -6.40666333e-09, + 6.77406054e-06, + -9.92992099e-04, + 5.60658043e-02, + 2.69372863e-04, + 2.17222337e+01, + -1.17749660e-10, + 6.100544547671263}}; + +double GetMTETime(double mknGB, int32_t m0, int32_t n0, double aBindWidth = 3.0, double bBindWidth = 3.0) +{ + // 预估Matmul计算的MTE2搬运时间 + return DOUBLE * mknGB * (SECOND_TO_MS / ONE_K) * (1.0 / (n0 * aBindWidth) + 1.0 / (m0 * bBindWidth)); +} + +int32_t AllReduceUbMoveNum(int m, int k, int n) +{ + double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40; + double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n; + double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K); + double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL); + double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW); + double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2); + double matmulPredict = std::max(cubePredict, mteTimePredict); + double c0 = matmulPredict / commPredict; + double c1 = 1.0 * m * n / k; + double c2 = sqrt(c1); + double c3 = sqrt(1.0 * m * n) / k; + double c4 = c3 * c3; + double c5 = matmulPredict; + double c6 = commPredict; + double c7 = 1.0 * n / m; + double c8 = 1.0 * m * n / sqrt(k); + double c9 = 1.0 * m * n * sqrt(k); + double c10 = sqrt(1.0 * m * n) * k; + double c11 = sqrt(1.0 * m * n * k); + double c12 = sqrt(1.0 * m * n); + double c13 = 1.0 * k * k / sqrt(1.0 * m * n); + double c14 = 1.0 * k * k * sqrt(1.0 * m * n); + double ubMoveNumDouble = 0; + std::vector feats_update = {c0, + c1, + c2, + c3, + c4, + c5, + c6, + c7, + 1.0 / c0, + 1.0 / c1, + 1.0 / c2, + 1.0 / c3, + 1.0 / c4, + c8, + c9, + c10, + c11, + c12, + c13, + 1.0 / c13, + c14, + 1}; + for (uint32_t i = 0; i < feats_update.size(); i++) { + ubMoveNumDouble += feats_update[i] * ALLREDUCE_UBMOVENUM_COEF[i]; + } + + return std::min(std::max(static_cast(ubMoveNumDouble) * HALF_KBYTE, MIN_UB_MOVE_NUM), MAX_UB_NUM); +} + +int32_t AllReducePValue(int m, int k, int n) +{ + double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40; + double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n; + double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K); + double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL); + double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW); + double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2); + 
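+    // Same feature construction as AllReduceUbMoveNum: predicted matmul time (max of cube-bound and
+    // MTE2-bound estimates), predicted communication time, and polynomial m/k/n terms are dotted with
+    // ALLREDUCE_PVALUE_COEF below, and the result is clamped to [1, MAX_P_VALUE].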
double matmulPredict = std::max(cubePredict, mteTimePredict); + double c0 = matmulPredict / commPredict; + double c1 = 1.0 * m * n / k; + double c2 = sqrt(c1); + double c3 = sqrt(1.0 * m * n) / k; + double c4 = c3 * c3; + double c5 = matmulPredict; + double c6 = commPredict; + double c7 = 1.0 * n / m; + double c8 = 1.0 * m * n / sqrt(k); + double c9 = 1.0 * m * n * sqrt(k); + double c10 = sqrt(1.0 * m * n) * k; + double c11 = sqrt(1.0 * m * n * k); + double c12 = sqrt(1.0 * m * n); + double c13 = 1.0 * k * k / sqrt(1.0 * m * n); + double c14 = 1.0 * k * k * sqrt(1.0 * m * n); + double pValueDouble = 0; + std::vector feats_update = {c0, + c1, + c2, + c3, + c4, + c5, + c6, + c7, + 1.0 / c0, + 1.0 / c1, + 1.0 / c2, + 1.0 / c3, + 1.0 / c4, + c8, + c9, + c10, + c11, + c12, + c13, + 1.0 / c13, + c14, + 1}; + for (uint32_t i = 0; i < feats_update.size(); i++) { + pValueDouble += feats_update[i] * ALLREDUCE_PVALUE_COEF[i]; + } + + return std::min(std::max(static_cast(pValueDouble), 1), MAX_P_VALUE); +} + +int32_t GetValueFromMKNConditionMap( + int32_t m, int32_t k, int32_t n, int32_t defaultValue, std::map>> conditionMap) +{ + int32_t value = defaultValue; + for (auto &item : conditionMap) { + for (auto &condition : item.second) { + bool inRange = m > condition[CONDITION_M_ST] && m <= condition[CONDITION_M_END] && + k > condition[CONDITION_K_ST] && k <= condition[CONDITION_K_END] && + n > condition[CONDITION_N_ST] && n <= condition[CONDITION_N_END]; + if (inRange) { + return item.first; + } + } + } + return value; +} + +void AllReduceEightRankFP16GetDefaultTiling( + gert::TilingContext *context, PPTilingData &ppTilingData, CommTilingData &commTilingData) +{ + int32_t m = ppTilingData.opShape.m; + int32_t k = ppTilingData.opShape.k; + int32_t n = ppTilingData.opShape.n; + + ppTilingData.m0 = + GetValueFromMKNConditionMap(m, k, n, ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT, ALLREDUCE_EIGHT_RANK_FP16_M0_MAP); + + ppTilingData.k0 = DEFAULT_COL; + ppTilingData.n0 = ppTilingData.m0 == DEFAULT_ROW ? 
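+        /* keep the L1 base block at 128x256 elements: n0 = 256 when m0 = 128, otherwise n0 = 128 */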
DEFAULT_COL : DEFAULT_ROW; + + ppTilingData.mLoop = CeilDev(m, ppTilingData.m0); + ppTilingData.nLoop = CeilDev(n, ppTilingData.n0); + ppTilingData.kLoop = CeilDev(k, ppTilingData.k0); + + ppTilingData.coreLoop = ppTilingData.opShape.batchSize * ppTilingData.mLoop * ppTilingData.nLoop; + ppTilingData.swizzlDirect = SWIZZLE_DIRECT_ONE; + ppTilingData.swizzlCount = DEFAULT_SWIZZLE_COUNT; + ppTilingData.tilingKey = 0; + ppTilingData.splitK = 0; + + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + ppTilingData.blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); + + commTilingData.ubMoveNum = + GetValueFromMKNConditionMap( + m, k, n, ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_MAP) * + HALF_KBYTE; + commTilingData.pValue = ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT; + + commTilingData.commDirect = COMM_DATA_DIRECT; + commTilingData.commNpuSplit = COMMNPUSPLIT_ONE; + commTilingData.commDataSplit = COMMDATASPLIT_SIXTEEN; + commTilingData.is91093 = 0; + commTilingData.withSerialMode = 0; + commTilingData.tag = 0; + commTilingData.write2OtherRank = 0; +} + +void GetDefaultTiling(gert::TilingContext *context, PPTilingData &ppTilingData, CommTilingData &commTilingData) +{ + int32_t m = ppTilingData.opShape.m; + int32_t k = ppTilingData.opShape.k; + int32_t n = ppTilingData.opShape.n; + + ppTilingData.m0 = DEFAULT_ROW; + ppTilingData.n0 = DEFAULT_COL; + ppTilingData.k0 = DEFAULT_COL; + + ppTilingData.mLoop = CeilDev(m, ppTilingData.m0); + ppTilingData.nLoop = CeilDev(n, ppTilingData.n0); + ppTilingData.kLoop = CeilDev(k, ppTilingData.k0); + ppTilingData.coreLoop = ppTilingData.opShape.batchSize * ppTilingData.mLoop * ppTilingData.nLoop; + + ppTilingData.swizzlDirect = m > n ? 
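+    /* swizzle direction follows the aspect ratio: 0 when m > n, otherwise 1 */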
0 : 1; + ppTilingData.swizzlCount = DEFAULT_SWIZZLE_COUNT; + ppTilingData.tilingKey = 0; + ppTilingData.splitK = 0; + + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + ppTilingData.blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); + + commTilingData.ubMoveNum = AllReduceUbMoveNum(m, k, n); + commTilingData.pValue = AllReducePValue(m, k, n); + commTilingData.commNpuSplit = commTilingData.rankSize; + commTilingData.commDataSplit = COMMDATASPLIT_ONE; + commTilingData.commDirect = COMM_DATA_DIRECT; + commTilingData.lenPerLoop = ppTilingData.m0 * ppTilingData.n0 * commTilingData.pValue * ppTilingData.blockDim; + commTilingData.lenPerLoop = commTilingData.lenPerLoop / commTilingData.rankSize; + commTilingData.is91093 = 0; + commTilingData.withSerialMode = 0; + commTilingData.tag = 0; + commTilingData.write2OtherRank = 0; +} + +static inline void GetRmsnormTilingData(RmsNormTilingData &rmsnormtiling, std::vector &shapeVec, + std::vector &oriShapeVec, uint32_t calcBytes = 0, uint32_t loopCount = 1, float ep = 1e-5) +{ + ge::Shape srcShape(shapeVec); + ge::Shape oriSrcShape(oriShapeVec); + uint32_t minValue = 0; + uint32_t maxValue = 0; + AscendC::GetRmsNormMaxMinTmpSize(srcShape, sizeof(uint16_t), maxValue, minValue, false); + + if (calcBytes < minValue) { + rmsnormtiling.calcBytes = minValue; + } else if (calcBytes > maxValue) { + rmsnormtiling.calcBytes = maxValue; + } else { + rmsnormtiling.calcBytes = calcBytes; + } + + optiling::RmsNormTiling tilingdata; + AscendC::GetRmsNormTilingInfo(srcShape, oriSrcShape, rmsnormtiling.calcBytes, sizeof(uint16_t), tilingdata, false); + size_t tilingSize = tilingdata.GetDataSize(); + tilingdata.SaveToBuffer(&rmsnormtiling.tiling, tilingSize); + rmsnormtiling.epsilon = ep; + rmsnormtiling.loopCount = loopCount; +} + +static inline void GetQuantTilingData(QuantInfo &quantInfo) +{ + quantInfo.dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + quantInfo.dequantGroupSize = -1; + quantInfo.quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + quantInfo.quantGroupSize = -1; +} + +static ge::graphStatus GetAttrAndSetTilingData( + gert::TilingContext *context, const char *nodeName, MatmulAllreduceAddRmsnormTilingData &tilingData) + +{ + CommTilingData &commTilingData = tilingData.matmulAllreduceAddRmsnormInfo.commTilingData; + PPTilingData &ppTilingData = tilingData.matmulAllreduceAddRmsnormInfo.ppTilingData; + RmsNormTilingData &rmsnormTilingData = tilingData.matmulAllreduceAddRmsnormInfo.rmsnormTilingData; + QuantInfo &quantInfo = tilingData.matmulAllreduceAddRmsnormInfo.quantInfo; + + auto attrs = context->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto RankSizePtr = attrs->GetAttrPointer(ATTR_RANK_SIZE_INDEX); + auto RankIdPtr = attrs->GetAttrPointer(ATTR_RANK_ID_INDEX); + + bool isTransB = *(attrs->GetAttrPointer(ATTR_IS_TRANS_B_INDEX)); + + ppTilingData.isTransA = false; + ppTilingData.isTransB = isTransB; + ppTilingData.isGatherAddOut = *(attrs->GetAttrPointer(ATTR_IS_GATHER_ADD_OUT_INDEX)); + + auto &opShape = ppTilingData.opShape; + auto &tensor0Shape = context->GetInputTensor(0)->GetOriginShape(); + uint32_t dimNum = tensor0Shape.GetDimNum(); + int64_t bs; + int64_t rankM; + int64_t rankK; + + if (dimNum == DIM_NUM_THREE) { + bs = tensor0Shape.GetDim(DIM_INDEX_ZERO); + rankM = tensor0Shape.GetDim(DIM_INDEX_ONE); + rankK = 
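+        /* 3-D x1 is interpreted as [bs, m, k] */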
tensor0Shape.GetDim(DIM_INDEX_TWO); + } else if (dimNum == DIM_NUM_TWO) { + bs = BATCH_SIZE_ONE; + rankM = tensor0Shape.GetDim(DIM_INDEX_ZERO); + rankK = tensor0Shape.GetDim(DIM_INDEX_ONE); + } else { + const char *nodeName = context->GetNodeName(); + OPS_LOG_E(nodeName, "Tiling input dim error."); + return ge::GRAPH_FAILED; + } + + int64_t rankN = isTransB ? + context->GetInputTensor(1)->GetOriginShape().GetDim(DIM_INDEX_ZERO) : + context->GetInputTensor(1)->GetOriginShape().GetDim(DIM_INDEX_ONE); + + opShape.batchSize = BATCH_SIZE_ONE; + opShape.m = bs * rankM; + opShape.n = rankN; + opShape.k = rankK; + + commTilingData.rankSize = static_cast(*RankSizePtr); + commTilingData.rank = static_cast(*RankIdPtr); + if (commTilingData.rankSize == RANKSIZE_EIGHT) { + AllReduceEightRankFP16GetDefaultTiling(context, ppTilingData, commTilingData); + } else { + GetDefaultTiling(context, ppTilingData, commTilingData); + } + + uint32_t calcBytes = 0; + uint32_t sLength = 1; + std::vector shapeVec = {1, 1, rankN}; + std::vector oriShapeVec = shapeVec; + auto EpsilonPtr = attrs->GetAttrPointer(ATTR_EPSILON_INDEX); + float epsilon = static_cast(*EpsilonPtr); + GetRmsnormTilingData( + rmsnormTilingData, shapeVec, oriShapeVec, calcBytes, commTilingData.rankSize * sLength * rankN, epsilon); + GetQuantTilingData(quantInfo); + + return ge::GRAPH_SUCCESS; +} + +bool IsMatrixAligned(const int64_t &m, const int64_t &n, const bool &transpose, int nElemAlign) +{ + return (transpose ? m : n) % nElemAlign == 0; +} + +int64_t GetAlignedMatrixSize( + const int64_t &batchSize, const int64_t &m, const int64_t &n, const bool &transpose, int nElemAlign) +{ + int64_t nRow = transpose ? n : m; + int64_t nCol = transpose ? m : n; + int64_t nColAlign = (nCol + nElemAlign - 1) / nElemAlign * nElemAlign; + return batchSize * nRow * nColAlign; +} + +WorkspaceDetail GetWorkspaceDetail(CoCDataTypeDesc dataType, const MatMulInfo &mmInfo, const QuantInfo &quantInfo) +{ + WorkspaceDetail workspaceDetail; + + int32_t eleSize = COC_TYPE2ELE_SIZE.at(dataType); + int32_t nElemAlign = ALIGN_BYTES / eleSize; + + bool hasQuant = quantInfo.quantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + if (hasQuant || (!IsMatrixAligned(mmInfo.m, mmInfo.k, mmInfo.transA, nElemAlign) && mmInfo.m != 1)) { + workspaceDetail.matrixActivationSize = + GetAlignedMatrixSize(mmInfo.batchSize, mmInfo.m, mmInfo.k, mmInfo.transA, nElemAlign) * eleSize; + } + + bool hasDequant = quantInfo.dequantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + if ((hasDequant && !mmInfo.isInt8) || !IsMatrixAligned(mmInfo.k, mmInfo.n, mmInfo.transB, nElemAlign)) { + workspaceDetail.matrixWeightSize = + GetAlignedMatrixSize(mmInfo.batchSize, mmInfo.k, mmInfo.n, mmInfo.transB, nElemAlign) * eleSize; + } + + bool hasAccum = dataType == CoCDataTypeDesc::INT8INT8_INT32_BF16; + if (hasAccum) { + workspaceDetail.matrixIntermediateSize = + static_cast(mmInfo.batchSize) * mmInfo.m * mmInfo.n * sizeof(int32_t); + } + + if (mmInfo.isInt8) { + workspaceDetail.formatDequantParamSize = + mmInfo.k > mmInfo.n ? 
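+            /* size the dequant-param buffer for the larger of k and n */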
mmInfo.k * sizeof(float) : mmInfo.n * sizeof(float); + } + return workspaceDetail; +} + +void GetMmInfo(gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling, MatMulInfo *mmInfo) +{ + PPTilingData tempPPTilingData = tiling->matmulAllreduceAddRmsnormInfo.ppTilingData; + mmInfo->batchSize = tempPPTilingData.opShape.batchSize; + mmInfo->m = tempPPTilingData.opShape.m; + mmInfo->n = tempPPTilingData.opShape.n; + mmInfo->k = tempPPTilingData.opShape.k; + auto attrs = context->GetAttrs(); + mmInfo->transA = false; + mmInfo->transB = *(attrs->GetAttrPointer(ATTR_IS_TRANS_B_INDEX)); + mmInfo->withBias = false; + mmInfo->weightNz = false; + mmInfo->isInt8 = context->GetInputTensor(0)->GetDataType() == ge::DT_INT8; +} + +size_t GetUserWorkspaceSize(gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling) +{ + MatMulInfo mmInfo; + GetMmInfo(context, tiling, &mmInfo); + QuantInfo quantInfo = tiling->matmulAllreduceAddRmsnormInfo.quantInfo; + return GetWorkspaceDetail(FP16FP16_FP32_FP16, mmInfo, quantInfo).GetSize(); +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t *workSpaces = context->GetWorkspaceSizes(1); + OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + size_t systemWorkspaceSize = static_cast(ascendcPlatform.GetLibApiWorkSpaceSize()); + MatmulAllreduceAddRmsnormTilingData *tilingData = context->GetTilingData(); + size_t userWorkspaceSize = GetUserWorkspaceSize(context, tilingData); + workSpaces[0] = userWorkspaceSize + systemWorkspaceSize; + return ge::GRAPH_SUCCESS; +} + +static void SetHcommCfg( + const gert::TilingContext *context, MatmulAllreduceAddRmsnormTilingData *tiling, const std::string groupTp) +{ + const char *nodeName = context->GetNodeName(); + uint32_t opType = OP_TYPE_ALL_TO_ALL; + std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh"; + + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupTp, opType, algConfigAllToAllStr); + mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling); +} + +static ge::graphStatus MatmulAllreduceAddRmsnormTilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + MatmulAllreduceAddRmsnormTilingData *tilingData = context->GetTilingData(); + OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); + + OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "Tiling set workspace failed."), + return ge::GRAPH_FAILED); + SetHcommCfg(context, tilingData, "hcomms"); + + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + OPS_LOG_E_IF_NULL(context, platformInfoPtr, return ge::GRAPH_FAILED); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + uint32_t aicNum_ = ascendcPlatform.GetCoreNumAic(); + context->SetBlockDim(aicNum_); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MatmulAllreduceAddRmsnormTilingFunc(gert::TilingContext *context) +{ + ge::graphStatus ret = MatmulAllreduceAddRmsnormTilingFuncImpl(context); + return ret; +} + +struct MatmulAllreduceAddRmsnormCompileInfo {}; +ge::graphStatus 
TilingParseForMatmulAllreduceAddRmsnorm(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(MatmulAllreduceAddRmsnorm) + .Tiling(MatmulAllreduceAddRmsnormTilingFunc) + .TilingParse(TilingParseForMatmulAllreduceAddRmsnorm); diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_workspace.h b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_workspace.h new file mode 100644 index 00000000..7da928f8 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_host/matmul_allreduce_add_rmsnorm_workspace.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_WORKSPACE_H +#define MATMUL_ALLREDUCE_ADD_RMSNORM_WORKSPACE_H + +#include + +#pragma once +const constexpr uint32_t ALIGN_BYTES = 512; +const constexpr int32_t INT8_ELE_SIZE = 1; +const constexpr int32_t FP_BF_16_ELE_SIZE = 2; + +enum CoCDataTypeDesc : int { + COC_DATA_TYPE_UNDEFINED = -1, + FP16FP16_FP32_FP16 = 0, + BF16BF16_FP32_BF16 = 1, + INT8INT8_INT32_FP16 = 2, + INT8INT8_INT32_BF16 = 3, + FP16INT8_INT32_FP16 = 4, + BF16INT8_INT32_BF16 = 5, + FP16INT8_FP32_FP16 = 6, + BF16INT8_FP32_BF16 = 7, + FP16INT4_FP32_FP16 = 8, + BF16INT4_FP32_BF16 = 9, + COC_DATA_TYPE_DESC_MAX = 10, +}; + +const std::map COC_TYPE2ELE_SIZE = { + {FP16FP16_FP32_FP16, FP_BF_16_ELE_SIZE}, + {BF16BF16_FP32_BF16, FP_BF_16_ELE_SIZE}, + {INT8INT8_INT32_FP16, INT8_ELE_SIZE}, + {INT8INT8_INT32_BF16, INT8_ELE_SIZE}, + {FP16INT8_INT32_FP16, INT8_ELE_SIZE}, + {BF16INT8_INT32_BF16, INT8_ELE_SIZE}, + {FP16INT8_FP32_FP16, FP_BF_16_ELE_SIZE}, + {BF16INT8_FP32_BF16, FP_BF_16_ELE_SIZE}, + {FP16INT4_FP32_FP16, FP_BF_16_ELE_SIZE}, + {BF16INT4_FP32_BF16, FP_BF_16_ELE_SIZE} +}; + +struct MatMulInfo { + int64_t batchSize = 1; + int64_t m = -1; + int64_t n = -1; + int64_t k = -1; + bool transA = false; + bool transB = false; + bool withBias = false; + bool isInt8 = false; + bool weightNz = false; +}; + +struct WorkspaceDetail { + int64_t matrixActivationSize{0}; + int64_t matrixWeightSize{0}; + int64_t matrixIntermediateSize{0}; + int64_t formatDequantParamSize{0}; + + int64_t GetSize() const + { + return matrixActivationSize + matrixWeightSize + matrixIntermediateSize + formatDequantParamSize; + } +}; + +#endif \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp new file mode 100644 index 00000000..3e76f15f --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lib/matmul_intf.h" +#include +#include "matmul_allreduce_add_rmsnorm_aic_kernel.h" +#include "matmul_allreduce_add_rmsnorm_aiv_kernel.h" + +extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm( + GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, + GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out, GM_ADDR workspace, GM_ADDR tiling) +{ + REGISTER_TILING_DEFAULT(MatmulAllreduceAddRmsnormTilingData); + GET_TILING_DATA(tiling_data, tiling); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + Hccl hccl_; + auto tilingData = (__gm__ MatmulAllreduceAddRmsnormTilingData*)tiling; + __gm__ void* mc2InitTiling = (__gm__ void*)(&(tilingData->mc2InitTiling)); + __gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling)); + auto contextGM0 = AscendC::GetHcclContext(); + + if ASCEND_IS_AIC { + MatmulAllreduceAddRmsnormAicKernel op; + op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_); + op.Process(); + return; + } + + if ASCEND_IS_AIV { + MatmulAllreduceAddRmsnormAivKernel op; + + op.Init(x1, x2, residual, gamma, y, add_out, workspace, &tiling_data, hccl_); + op.Process(&tiling_data); + return; + } +} diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aic_kernel.h b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aic_kernel.h new file mode 100644 index 00000000..f3a0dc25 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aic_kernel.h @@ -0,0 +1,359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H +#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H + +#define ASCENDC_CUBE_ONLY + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/gemm/block/block_mmad.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/kernel/basic_matmul.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/layout/layout.hpp" + +#include "matmul_allreduce_add_rmsnorm_utils.h" +#include "matmul_allreduce_add_rmsnorm_tiling.h" + +constexpr int32_t SCALE_L1_SIZE_A = 256 * 8; +constexpr int32_t SCALE_L1_SIZE_B = 128 * 1024; +constexpr int32_t CUBE_MATRIX_SIZE_B16 = 256; // 16 * 16 +constexpr int32_t CUBE_MATRIX_SIZE_B8 = 16 * 32; // 16 * 32 +constexpr int32_t SCALE_L1_SIZE = 256 * 8; // 2 KB +constexpr int32_t BLOCK_SIZE_16 = 16; +constexpr int32_t BLOCK_SIZE_32 = 32; +constexpr int32_t DOUBLE_BUFFER_SIZE = 2; +constexpr uint32_t MM_L1_TILE_SHAPE_M = 128; +constexpr uint32_t MM_L1_TILE_SHAPE_N = 256; +constexpr uint32_t MM_L1_TILE_SHAPE_K = 256; +constexpr uint32_t MM_L0_TILE_SHAPE_M = MM_L1_TILE_SHAPE_M; +constexpr uint32_t MM_L0_TILE_SHAPE_N = MM_L1_TILE_SHAPE_N; +constexpr uint32_t MM_L0_TILE_SHAPE_K = 64; + +using namespace Catlass; + +template +struct GetAccumType { + using T = float; +}; + +__aicore__ inline bool IsQuant(const QuantGranularity &granularity) +{ + return (granularity > QuantGranularity::QUANT_GRANULARITY_UNDEFINED) && + (granularity < QuantGranularity::QUANT_GRANULARITY_MAX); +} + +template +class MatmulAllreduceAddRmsnormAicKernel { + using T_ACCUM = typename GetAccumType::T; +public: + int PIPE_DEPTH = 2; + Arch::Resource resource; + __aicore__ inline MatmulAllreduceAddRmsnormAicKernel() { } + + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y, + GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData* tilingData, + Hccl &hccl_) + { + this->hccl_ = hccl_; + this->gm_c = reinterpret_cast<__gm__ OutDtype *>(y); + + this->gm_dequant_scale = nullptr; + this->has_offset = false; + + auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData; + auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData; + auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo; + + this->batch_size = ppTilingData->opShape.batchSize; + this->m = ppTilingData->opShape.m; + this->k = ppTilingData->opShape.k; + this->n = ppTilingData->opShape.n; + this->weight_nz = false; + + this->is_int8 = false; + this->cube_matrix_size = this->is_int8 ? 
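+            /* 16x32 cube fractal for int8, 16x16 for fp16/bf16 */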
CUBE_MATRIX_SIZE_B8 : CUBE_MATRIX_SIZE_B16; + + this->m_align = Block512B::AlignUp(m); + this->k_align = Block512B::AlignUp(k); + this->n_align = Block512B::AlignUp(n); + + this->m0 = ppTilingData->m0; + this->k0 = ppTilingData->k0; + this->n0 = ppTilingData->n0; + + int32_t tiling_key = ppTilingData->tilingKey; + this->trans_a = ppTilingData->isTransA; + this->trans_b = ppTilingData->isTransB; + + int32_t aligned_a; + int32_t aligned_b; + this->dequant_granularity = quantInfo->dequantGranularity; + AlignJudge(this->trans_a, this->trans_b, this->m, this->k, this->n, + this->m_align, this->k_align, this->n_align, aligned_a, aligned_b); + this->aligned_a = aligned_a; + this->aligned_b = aligned_b; + if (weight_nz) { + this->k_align16 = Block32B::AlignUp(k); + this->n_align16 = Block32B::AlignUp(n); + } + bool has_a_align = IsQuant(quantInfo->quantGranularity) || aligned_a; + bool has_b_align = IsQuant(this->dequant_granularity) && !this->is_int8 || aligned_b; + bool has_accum = IsQuant(this->dequant_granularity) && + this->is_int8 && std::is_same::value; + bool has_format_dequant_offset = + (this->dequant_granularity == QuantGranularity::PER_TENSOR) && this->is_int8 && this->has_offset; + auto workspace_info = GetWorkspaceInfo(workspace, this->batch_size, this->m, this->k, this->n, + this->m_align, this->k_align, this->n_align, this->trans_a, this->trans_b, + sizeof(MmadDtype), has_a_align, has_b_align, has_accum, has_format_dequant_offset); + this->gm_a_src = reinterpret_cast<__gm__ MmadDtype *>(x1); + this->gm_b_src = reinterpret_cast<__gm__ MmadDtype *>(x2); + this->gm_format_dequant_offset = reinterpret_cast<__gm__ int32_t *>(has_format_dequant_offset ? + workspace_info.gm_dequant_param : nullptr); + this->gm_workspace_src = workspace; + this->block_size = BLOCK_SIZE_32 / sizeof(MmadDtype); + + int32_t a_l1_size = this->m0 * this->k0 * sizeof(MmadDtype); + int32_t a_l1_size_round = AscendC::DivCeil(a_l1_size, 512) * 512; + int32_t b_l1_size = this->n0 * this->k0 * sizeof(MmadDtype); + int32_t b_l1_size_round = AscendC::DivCeil(b_l1_size, 512) * 512; + this->l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t)(this->is_int8 ? SCALE_L1_SIZE : 0)); + this->l1_base_b = + reinterpret_cast<__cbuf__ MmadDtype *>(a_l1_size_round * (this->is_int8 ? 
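+            /* int8 double-buffers the A tile in L1, so B starts after two A blocks */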
DOUBLE_BUFFER_SIZE : 1) + + (uintptr_t) this->l1_base_a); + + this->core_num = get_block_num(); + this->core_idx = get_block_idx(); + + this->m_loop = ppTilingData->mLoop; + this->k_loop = ppTilingData->kLoop; + this->n_loop = ppTilingData->nLoop; + this->core_loop = ppTilingData->coreLoop; + this->swizzl_count = ppTilingData->swizzlCount; + this->swizzl_direct = ppTilingData->swizzlDirect; + this->is_91093 = commTilingData->is91093; + this->ping_flag = 1; + this->rank = hccl_.GetRankId(); + this->rank_size = hccl_.GetRankDim(); + this->withSerialMode = commTilingData->withSerialMode; + + this->gm_peer_mem = (__gm__ OutDtype *)hccl_.GetWindowsInAddr(this->rank); + } + + __aicore__ inline void MoveL0CToGM(__gm__ OutDtype *gm_dst, int64_t offset_c, + int32_t m_actual, int32_t n_actual, int32_t src_stride, int32_t dst_stride) { + if constexpr (std::is_same::value) { + copy_matrix_cc_to_gm( + gm_dst + offset_c, + l0c_buf, + 0, + n_actual, + m_actual, + dst_stride, + src_stride, + 0, + F322BF16, + 0, + false, + true + ); + } else { + copy_matrix_cc_to_gm( + gm_dst + offset_c, + l0c_buf, + 0, + n_actual, + m_actual, + dst_stride, + src_stride, + 0, + F322F16, + 0, + false, + true + ); + } + SetFlag(EVENT_ID0); + } + + __aicore__ inline void InitFlags() + { + WaitEvent(AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID); + } + + __aicore__ inline void Endflags() + { + } + + __aicore__ inline void Process() + { + // AIC matmul func, waits for AIV to complete [AllReduce & Add & RMSNorm]. + InitFlags(); + uint32_t m = this->m; + uint32_t k = this->k; + uint32_t n = this->n; + gmB.SetGlobalBuffer(gm_b_src, k * n); + + using LayoutA = layout::RowMajor; + using LayoutB = layout::ColumnMajor; + using LayoutC = layout::RowMajor; + LayoutB layoutB {(layout::ColumnMajor::Index)k, (layout::ColumnMajor::Index)n}; + + using L1TileShape = GemmShape; + using L0TileShape = GemmShape; + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = AType; + constexpr bool ENABLE_UNIT_FLAG = true; + using MmadDispatchPolicy = Gemm::MmadAtlasA2Pingpong; + using BlockMmad = Gemm::Block::BlockMmad; + GemmCoord blockShape = L1TileShape::ToCoord(); + + BlockMmad blockMmad(resource); + int mPerSplit = this->m0 * this->swizzl_count; + int mAvg = mPerSplit; + int splitM = AscendC::DivCeil(m, mPerSplit); + int flag_idx = 0; + icache_preload(8); // 8 corresponding to 16k + for (int splitIndex = 0; splitIndex < splitM; ++splitIndex) { + uint32_t mStart = splitIndex * mAvg; + uint32_t mActual = mAvg > (m - mStart) ? 
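+            /* clamp the last m-split to the rows that remain */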
m - mStart:mAvg; + flag_idx = splitIndex % PIPE_DEPTH; + if (splitIndex >= PIPE_DEPTH) { + WaitEvent(flag_idx); + } + + __gm__ MmadDtype *gm_a_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_a_src) + mStart * k; + __gm__ MmadDtype *gm_c_src_tmp = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) + mStart * n; + gmA.SetGlobalBuffer(gm_a_src_tmp, mActual*k); + gmC.SetGlobalBuffer(gm_c_src_tmp, mActual*n); + + GemmCoord splitShape{mActual, n, k}; + using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 1>; // SwizzleOffset=3 + BlockScheduler splitScheduler(splitShape, blockShape.GetCoordMN()); + uint32_t coreLoops = splitScheduler.GetCoreLoops(); + + LayoutA layoutA{mActual, k}; + LayoutC layoutC{mActual, n}; + + for (uint32_t loopIdx = core_idx; loopIdx < coreLoops; loopIdx += core_num) { + GemmCoord blockCoord = splitScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = splitScheduler.GetActualBlockShape(blockCoord); + GemmCoord offsetCoord = blockCoord * blockShape; + + MatrixCoord offsetA = offsetCoord.GetCoordMK(); + MatrixCoord offsetB = offsetCoord.GetCoordKN(); + MatrixCoord offsetC = offsetCoord.GetCoordMN(); + + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + blockMmad (gmA[gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], layoutC, actualBlockShape); + } + + FFTSCrossCoreSync(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx); + } + + Endflags(); + PipeBarrier(); + } + +private: + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmC; + __gm__ MmadDtype *gm_a_src{nullptr}; + __gm__ MmadDtype *gm_b_src{nullptr}; + + __gm__ OutDtype *gm_c{nullptr}; + __gm__ OutDtype *gm_peer_mem{nullptr}; + __gm__ int64_t *gm_dequant_scale{nullptr}; + __gm__ int32_t *gm_format_dequant_offset{nullptr}; + __gm__ int32_t *gm_accum{nullptr}; + __gm__ uint8_t *gm_workspace_src; + + __cbuf__ MmadDtype *l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_A); + __cbuf__ MmadDtype *l1_base_b = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE_B); + + __ca__ MmadDtype *l0a_base = reinterpret_cast<__ca__ MmadDtype *>((uintptr_t) 0); + __cb__ MmadDtype *l0b_base = reinterpret_cast<__cb__ MmadDtype *>((uintptr_t) 0); + + __cc__ T_ACCUM *l0c_buf = reinterpret_cast<__cc__ T_ACCUM *>((uintptr_t) 0); + + __cbuf__ int64_t *scale_l1 = reinterpret_cast<__cbuf__ int64_t *>((uintptr_t) 0); + __fbuf__ int64_t *scale_FB = (__fbuf__ int64_t *)(0); + + __cbuf__ int32_t *bias_l1 = reinterpret_cast<__cbuf__ int32_t *>((uintptr_t)0); + uint16_t bias_bt = 0; + bool has_offset{false}; + + int32_t core_num; + + int32_t batch_size; + int32_t m; + int32_t k; + int32_t n; + int32_t m_align; + int32_t k_align; + int32_t n_align; + int32_t k_align16; + int32_t n_align16; + int32_t m0; + int32_t k0; + int32_t n0; + + int32_t m_loop; + int32_t n_loop; + int32_t k_loop; + int32_t core_loop; + int32_t core_idx; + int32_t ping_flag; + int32_t block_size; + int32_t cube_matrix_size; + + int32_t aligned_a; + int32_t aligned_b; + + int32_t swizzl_count; + int32_t swizzl_direct; + + int32_t rank; + int32_t rank_size; + + int32_t withSerialMode; + + int32_t ag_dim; + int32_t rs_dim; + bool inner_dim_is_Ag{false}; + int32_t ag_rank_idx; + int32_t rs_rank_idx; + bool weight_nz{false}; + + bool is_91093{false}; + QuantGranularity dequant_granularity; + + bool is_int8; + bool trans_a; + bool trans_b; + + Hccl hccl_; +}; + +#endif // 
MATMUL_ALLREDUCE_ADD_RMSNORM_AIC_KERNEL_H diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aiv_kernel.h b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aiv_kernel.h new file mode 100644 index 00000000..c76d39eb --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_aiv_kernel.h @@ -0,0 +1,702 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H +#define MATMUL_ALLREDUCE_ADD_RMSNORM_AIV_KERNEL_H + +#include "kernel_operator.h" +#include "matmul_allreduce_add_rmsnorm_tiling.h" +#include "matmul_allreduce_add_rmsnorm_utils.h" + +using namespace AscendC; + +constexpr int32_t DIFUSION_ADD_LEN = 512; +constexpr int32_t TQUE_DEPTH = 1; +constexpr uint32_t TBUF_POOL_MAX_BUFID_SIZE = 8; +enum CrossRankSyncFlagEnum { + FLAG_ZERO_IDX, + FLAG_ONE_IDX, + FLAG_TWO_IDX, + FLAG_ADD_IDX, + FLAG_FOUR_IDX, + FLAG_GATHER_ADD_OUT_STEP1, + FLAG_GATHER_ADD_OUT_STEP2, + FLAG_NUM +}; +constexpr int32_t FLAG_VALUE = 1; +constexpr int32_t NUM_PER_REP_FP32 = 64; + +template +__aicore__ void CopyUbufToGmAlignB16(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint16_t srcSTride, uint16_t dstStride) +{ + DataCopyExtParams dataCopyParams(nBurst, + lenBurst, + srcSTride, + dstStride, + 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(src); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(dst)); + DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +template +__aicore__ void CopyGmToUbufAlignB16(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint16_t srcSTride, uint16_t dstStride) +{ + DataCopyExtParams dataCopyParams(nBurst, + lenBurst, + srcSTride, + dstStride, + 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(dst); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(src)); + DataCopyPadExtParams padParams; + DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +template +class MatmulAllreduceAddRmsnormAivKernel { + +public: + __aicore__ inline MatmulAllreduceAddRmsnormAivKernel() { } + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR residual, GM_ADDR gamma, GM_ADDR y, GM_ADDR add_out, + GM_ADDR workspace, const MatmulAllreduceAddRmsnormTilingData *tilingData, + Hccl &hccl_) + { + this->hccl_ = hccl_; + is_deterministic = false; + auto ppTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData; + auto commTilingData = &tilingData->matmulAllreduceAddRmsnormInfo.commTilingData; + auto quantInfo = &tilingData->matmulAllreduceAddRmsnormInfo.quantInfo; + + gm_out = reinterpret_cast<__gm__ MmadDtype 
*>(y); + gm_add_input = reinterpret_cast<__gm__ MmadDtype *>(residual); + gm_add_output = reinterpret_cast<__gm__ MmadDtype *>(add_out); + gm_gamma = reinterpret_cast<__gm__ MmadDtype *>(gamma); + + batch_size = ppTilingData->opShape.batchSize; + m = ppTilingData->opShape.m; + k = ppTilingData->opShape.k; + n = ppTilingData->opShape.n; + + m0 = ppTilingData->m0; + k0 = ppTilingData->k0; + n0 = ppTilingData->n0; + + m_loop = ppTilingData->mLoop; + k_loop = ppTilingData->kLoop; + n_loop = ppTilingData->nLoop; + + core_loop = ppTilingData->coreLoop; + swizzl_count = ppTilingData->swizzlCount; + tiling_key = ppTilingData->tilingKey; + rank = hccl_.GetRankId(); + rank_size = hccl_.GetRankDim(); + + max_ub_single_dma_size = commTilingData->ubMoveNum; + withSerialMode = false; + tag = commTilingData->tag; + comm_npu_split = commTilingData->commNpuSplit; + comm_data_split = commTilingData->commDataSplit; + comm_direct = commTilingData->commDirect; + is_91093 = false; + core_count = comm_npu_split * comm_data_split; + dequant_granularity = static_cast(quantInfo->dequantGranularity); + dequant_group_size = quantInfo->dequantGroupSize; + quant_granularity = static_cast(quantInfo->quantGranularity); + quant_group_size = quantInfo->quantGroupSize; + epsilon = tilingData->matmulAllreduceAddRmsnormInfo.rmsnormTilingData.epsilon; + is_gather_add_out = tilingData->matmulAllreduceAddRmsnormInfo.ppTilingData.isGatherAddOut; + + swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false; + trans_a = ppTilingData->isTransA; + trans_b = ppTilingData->isTransB; + is_int8 = false; + ag_dim = 0; + rs_dim = 0; + inner_dim_is_Ag = false; + weight_nz = false; + max_ub_ping_pong_size = max_ub_single_dma_size / 2; // 2 - double buffer + + core_idx = get_block_idx(); + core_num = get_block_num(); + aiv_idx = get_subblockid(); + other_rank = (core_idx < rank_size) ? core_idx : -1; + + // init ub usage + pipe.InitBuffer(ctrlBuf, AscendC::ONE_BLK_SIZE); + ub_ctrl_flag = reinterpret_cast<__ubuf__ int32_t *>(ctrlBuf.Get().GetPhyAddr()); + + pipe.InitBuffer(gammaBuf, n * sizeof(MmadDtype)); + + uint32_t step1_ub_usage = AscendC::AlignUp( + n * sizeof(MmadDtype) + + 2 * (rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype)) + + n * sizeof(MmadDtype) + + n * sizeof(MmadDtype) + + n * sizeof(float) + + n * sizeof(float) + + n * sizeof(float), + AscendC::ONE_BLK_SIZE); + + uint32_t step2_ub_usage = AscendC::AlignUp( + max_ub_ping_pong_size * sizeof(MmadDtype), + AscendC::ONE_BLK_SIZE) * 2; + uint32_t max_step_ub_usage = max(step1_ub_usage, step2_ub_usage); + + pipe.InitBufPool(step1BufPool, max_step_ub_usage); + pipe.InitBufPool(step2BufPool, max_step_ub_usage, step1BufPool); + + step1BufPool.InitBuffer(inQueueX, 1, n * sizeof(MmadDtype)); + step1BufPool.InitBuffer(inQueueY, 2, rank_size * DIFUSION_ADD_LEN * sizeof(MmadDtype)); + step1BufPool.InitBuffer(addOutQueue, 1, n * sizeof(MmadDtype)); + step1BufPool.InitBuffer(outQueue, 1, n * sizeof(MmadDtype)); + step1BufPool.InitBuffer(xFp32Buf, n * sizeof(float)); + step1BufPool.InitBuffer(sqxBuf, n * sizeof(float)); + step1BufPool.InitBuffer(reduceFp32Buf, n * sizeof(float)); + + step2BufPool.InitBuffer(allgatherBuf[0], max_ub_ping_pong_size * sizeof(MmadDtype)); + step2BufPool.InitBuffer(allgatherBuf[1], max_ub_ping_pong_size * sizeof(MmadDtype)); + + CopyInGamma(); + } + + __aicore__ inline void Process(const MatmulAllreduceAddRmsnormTilingData *tilingData) + { + // AIV AllReduce & Add & RMSNorm func, waits for AIC to complete [Matmul]. 
+ FFTSCrossCoreSync(FFTS_SYNC_AICORE_GROUP_MODE, AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID); + PipeBarrier(); + + ResetIpcFlags(FLAG_NUM); + CrossRankSyncEx(FLAG_NUM); + constexpr int32_t allreduce_used_core = 16; + int32_t one_comm_count = swizzl_count; + int32_t loop_num_per_comm = one_comm_count * n_loop; + int32_t comm_count = DivCeil(core_loop, loop_num_per_comm); + int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + + for (int cal_idx = 0; cal_idx < comm_count; ++cal_idx) { + uint64_t flag_idx = cal_idx % pipe_depth; + int32_t m_total = (cal_idx == comm_count - 1) ? + m - cal_idx * swizzl_count * m0 : swizzl_count * m0; + int32_t m_per_rank = DivCeil(m_total, rank_size); + int32_t loop_offset = cal_idx * swizzl_count * m0; + + WaitEvent(flag_idx); + SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT); + CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1); + SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT); + + if (aiv_idx == 0 && core_idx < allreduce_used_core) { + int32_t m_cur_rank = LimitRange(m_total - rank * m_per_rank, 0, m_per_rank); + int32_t m_per_core = DivCeil(m_cur_rank, allreduce_used_core); + int32_t m_cur_core = LimitRange(m_cur_rank - core_idx * m_per_core, 0, m_per_core); + int32_t core_offset_m = loop_offset + rank * m_per_rank + core_idx * m_per_core; + ParallelWithSplitStepOneAddNorm(core_offset_m * n, m_cur_core); + } + + PipeBarrier(); + + SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT); + CrossRankSyncV1(FLAG_ADD_IDX, cal_idx + 1); + SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT); + + { // ParallelWithSplitStepTwo + int32_t used_core_per_rank = allreduce_used_core / rank_size; + int32_t sub_core_idx = core_idx % used_core_per_rank; + int32_t gather_rank_id = core_idx / used_core_per_rank; + int32_t m_in_rank = LimitRange(m_total - gather_rank_id * m_per_rank, 0, m_per_rank); + int32_t m_per_core = DivCeil(m_in_rank, used_core_per_rank); + int32_t m_cur_core = LimitRange(m_in_rank - sub_core_idx * m_per_core, 0, m_per_core); + int32_t core_offset_m = loop_offset + gather_rank_id * m_per_rank + sub_core_idx * m_per_core; + auto gm_share_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(gather_rank_id); + + bool filter_core_cond = aiv_idx == 0 && core_idx < allreduce_used_core && m_cur_core > 0; + if (filter_core_cond) { + ParallelAllGather(gm_out, gm_share_buff, core_offset_m * n, m_cur_core * n); + } + + SetAndWaitAivSync(flag_idx); + CrossRankSyncV2(FLAG_TWO_IDX, cal_idx + 1); + SetAndWaitAivSync(flag_idx); + + if (is_gather_add_out) { + if (filter_core_cond && gather_rank_id == rank) { + ParallelAllGather(gm_share_buff, gm_add_output, core_offset_m * n, m_cur_core * n); + } + + SetAndWaitAivSync(flag_idx); + CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP1, cal_idx + 1); + SetAndWaitAivSync(flag_idx); + + if (filter_core_cond && gather_rank_id != rank) { + ParallelAllGather(gm_add_output, gm_share_buff, core_offset_m * n, m_cur_core * n); + } + + SetAndWaitAivSync(flag_idx); + CrossRankSyncV2(FLAG_GATHER_ADD_OUT_STEP2, cal_idx + 1); + SetAndWaitAivSync(flag_idx); + } + } + + if (cal_idx <= comm_count - pipe_depth) { + SetAicSync(flag_idx); + } + } + ResetIpcFlags(FLAG_NUM); + if (aiv_idx == 0 && core_idx < rank_size) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(other_rank); + CheckBuffFlag(ub_ctrl_flag, state_buff + FLAG_ZERO_IDX, 0); + } + } + +private: + __aicore__ void SetBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t 
flag) + { + *ub_ctrl_flag = flag; + SetFlag(EVENT_ID2); + WaitFlag(EVENT_ID2); + CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0); + } + + __aicore__ void SetBuffFlagByAdd(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag) + { + PipeBarrier(); + *ub_ctrl_flag = flag; + PipeBarrier(); + SetAtomicAdd(); + PipeBarrier(); + CopyUbufToGmAlignB16(buff, ub_ctrl_flag, 1, sizeof(int32_t), 0, 0); + PipeBarrier(); + SetAtomicNone(); + PipeBarrier(); + } + + __aicore__ void CheckBuffFlag(__ubuf__ int32_t *ub_ctrl_flag, __gm__ int32_t *buff, int32_t flag) + { + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + while (true) { + CopyGmToUbufAlignB16(ub_ctrl_flag, buff, 1, sizeof(int32_t), 0, 0); + SetFlag(EVENT_ID3); + WaitFlag(EVENT_ID3); + if (*ub_ctrl_flag == flag) { + break; + } + } + } + + __aicore__ void SetAicSync(uint64_t flag_idx) + { + FFTSCrossCoreSync(FFTS_SYNC_AICORE_GROUP_MODE, flag_idx); + } + + __aicore__ void ResetIpcFlags(int32_t num_flags) + { + for (int32_t idx = 0; idx <= num_flags; ++idx) { + if (core_idx == 0 && aiv_idx == 0) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank); + SetBuffFlag(ub_ctrl_flag, state_buff + idx, 0); + } + } + } + + __aicore__ void CrossRankSyncV1(int32_t flag_idx, int32_t flag_data) + { + if (aiv_idx == 0 && core_idx == rank) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank); + SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE); + } else if (aiv_idx == 0 && core_idx < rank_size) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx); + CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * flag_data); + } + } + + __aicore__ void CrossRankSyncV2(int32_t flag_idx, int32_t flag_data) + { + if (aiv_idx == 0 && core_idx < rank_size) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(core_idx); + SetBuffFlagByAdd(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE); + } + if (aiv_idx == 0 && core_idx == rank) { + __gm__ int32_t *state_buff = (__gm__ int32_t *)hccl_.GetWindowsOutAddr(rank); + CheckBuffFlag(ub_ctrl_flag, state_buff + flag_idx, FLAG_VALUE * rank_size * flag_data); + } + } + + __aicore__ void SetAndWaitAivSync(uint64_t flag_idx, int32_t pipe_depth = 2) + { + FFTSCrossCoreSync(0, flag_idx + pipe_depth); + WaitEvent(flag_idx + pipe_depth); + } + + __aicore__ inline uint32_t GetGmU32(GM_ADDR gm_addr) + { + copy_gm_to_ubuf_align_b32(ub_ctrl_flag, gm_addr, 0, 1, sizeof(uint32_t), 0, 0, 0, 0); + PipeSync(); + return *reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag); + } + + __aicore__ inline void SetGmU32(GM_ADDR gm_addr, uint32_t data) + { + *reinterpret_cast<__ubuf__ uint32_t *>(ub_ctrl_flag) = data; + PipeSync(); + copy_ubuf_to_gm_align_b32(gm_addr, ub_ctrl_flag, 0, 1, sizeof(uint32_t), 0, 0, 0, 0); + } + + __aicore__ inline void CrossRankSyncEx(uint32_t flag_idx) + { + AscendC::SyncAll(); + __asm__ __volatile__(""); + if (aiv_idx == 0 && core_idx == 0) { + auto flag_addr = (GM_ADDR)hccl_.GetWindowsOutAddr(0) + flag_idx * AscendC::ONE_BLK_SIZE; + uint32_t old_flag_data = GetGmU32(flag_addr); + __asm__ __volatile__(""); + SetAtomicAdd(); + SetGmU32(flag_addr, 1); + PipeSync(); + SetAtomicNone(); + __asm__ __volatile__(""); + + uint32_t new_flag_data; + do { + new_flag_data = GetGmU32(flag_addr); + __asm__ __volatile__(""); + } while (new_flag_data - old_flag_data < rank_size); + __asm__ __volatile__(""); + SetAtomicAdd(); + SetGmU32(flag_addr, 1); + PipeSync(); + SetAtomicNone(); + } + 
__asm__ __volatile__(""); + AscendC::SyncAll(); + } + + template + __aicore__ inline T min(const T& a, const T& b) { + return (a < b) ? a : b; + } + + template + __aicore__ inline T max(const T& a, const T& b) { + return (a > b) ? a : b; + } + + template + __aicore__ inline T LimitRange(const T& val, const T& low, const T& high) { + return min(max(val, low), high); + } + + template + __aicore__ inline void PipeSync() + { + AscendC::TEventID event_id = static_cast(GetTPipePtr()->FetchEventID(EVENT)); + AscendC::SetFlag(event_id); + AscendC::WaitFlag(event_id); + } + + __aicore__ inline void CopyInGamma() + { + GlobalTensor gamma_global; + gamma_global.SetGlobalBuffer((__gm__ MmadDtype *)gm_gamma, n); + DataCopy(gammaBuf.Get(), gamma_global, n); + PipeSync(); + } + + __aicore__ void ParallelWithSplitStepOneAddNorm(uint32_t core_buf_offset, uint32_t m_cur_core) + { + if (m_cur_core <= 0) { + return; + } + + auto buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(rank); + + GlobalTensor x_global; + GlobalTensor y_global; + GlobalTensor out_global; + GlobalTensor add_out_global; + + x_global.SetGlobalBuffer(buff + core_buf_offset); + out_global.SetGlobalBuffer(buff + core_buf_offset); + add_out_global.SetGlobalBuffer(gm_add_output + core_buf_offset); + + uint32_t add_count = DivCeil(n, DIFUSION_ADD_LEN); + + LocalTensor x_local; + LocalTensor y_local; + + for (uint32_t i = 0; i < m_cur_core; i++) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + + x_local = inQueueX.AllocTensor(); + for (uint32_t j = 0; j < add_count; j++) { + uint32_t add_offset = j * DIFUSION_ADD_LEN; + uint32_t add_len = min(n - add_offset, DIFUSION_ADD_LEN); + + DataCopy(x_local[add_offset], x_global[i * n + add_offset], add_len); + inQueueX.EnQue(x_local); + + uint32_t iterate_end = (rank + 1) % rank_size; + y_local = inQueueY.AllocTensor(); + for (uint32_t k = 0; k < rank_size; ++k) { + uint32_t iterate_idx = iterate_end + k; + if (iterate_idx >= rank_size) { + iterate_idx -= rank_size; + } + + if (iterate_idx == rank) { + y_global.SetGlobalBuffer(gm_add_input + core_buf_offset); + } else { + auto other_buff = (__gm__ MmadDtype *)hccl_.GetWindowsInAddr(iterate_idx); + y_global.SetGlobalBuffer(other_buff + core_buf_offset); + } + DataCopy(y_local[k * add_len], y_global[i * n + add_offset], add_len); + } + inQueueY.EnQue(y_local); + x_local = inQueueX.DeQue(); + y_local = inQueueY.DeQue(); + + Cast(x_fp32[add_offset], x_local[add_offset], RoundMode::CAST_NONE, add_len); + PipeBarrier(); + for (uint32_t k = 0; k < rank_size; ++k) { + // use sqx as shared buf, required n >= add_len + Cast(sqx, y_local[k * add_len], RoundMode::CAST_NONE, add_len); + PipeBarrier(); + Add(x_fp32[add_offset], x_fp32[add_offset], sqx, add_len); + PipeBarrier(); + } + + inQueueY.FreeTensor(y_local); + } + inQueueX.FreeTensor(x_local); + + // copy add result out + LocalTensor add_out = addOutQueue.AllocTensor(); + Cast(add_out, x_fp32, RoundMode::CAST_RINT, n); + addOutQueue.EnQue(add_out); + add_out = addOutQueue.DeQue(); + DataCopy(add_out_global[i * n], add_out, n); + addOutQueue.FreeTensor(add_out); + + LocalTensor gamma_local = gammaBuf.Get(); + LocalTensor out_local = outQueue.AllocTensor(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + + // make sure precision is same in bf16 case + Cast(out_local, x_fp32, RoundMode::CAST_RINT, n); + PipeBarrier(); + + Cast(x_fp32, out_local, RoundMode::CAST_NONE, n); + PipeBarrier(); + + Mul(sqx, x_fp32, x_fp32, n); + PipeBarrier(); + + Muls(sqx, sqx, (float)1.0 / n, n); + 
PipeBarrier(); + + ReduceSum(sqx, sqx, reduce_buf_local, n); + PipeBarrier(); + + Adds(sqx, sqx, epsilon, 1); + PipeBarrier(); + + Sqrt(sqx, sqx, 1); + Duplicate(reduce_buf_local, (float)1.0, 1); + PipeBarrier(); + + Div(sqx, reduce_buf_local, sqx, 1); + PipeBarrier(); + + PipeSync(); + float rstd_value = sqx.GetValue(0); + PipeSync(); + PipeBarrier(); + + Muls(x_fp32, x_fp32, rstd_value, n); + PipeBarrier(); + + if constexpr (std::is_same::value) { + Cast(out_local, x_fp32, RoundMode::CAST_NONE, n); + PipeBarrier(); + Mul(out_local, gamma_local, out_local, n); + PipeBarrier(); + } else if constexpr (std::is_same::value) { + Cast(out_local, x_fp32, RoundMode::CAST_RINT, n); + PipeBarrier(); + Cast(x_fp32, out_local, RoundMode::CAST_NONE, n); + PipeBarrier(); + Cast(sqx, gamma_local, RoundMode::CAST_NONE, n); + PipeBarrier(); + + Mul(x_fp32, x_fp32, sqx, n); + PipeBarrier(); + Cast(out_local, x_fp32, RoundMode::CAST_RINT, n); + PipeBarrier(); + PipeSync(); + } + + outQueue.EnQue(out_local); + out_local = outQueue.DeQue(); + DataCopy(out_global[i * n], out_local, n); + outQueue.FreeTensor(out_local); + } + } + + __aicore__ void ParallelAllGather(__gm__ MmadDtype *gm_dst, __gm__ MmadDtype *gm_src, + uint32_t core_buf_offset, uint32_t data_len) + { + GlobalTensor src_global; + GlobalTensor dst_global; + src_global.SetGlobalBuffer(gm_src); + dst_global.SetGlobalBuffer(gm_dst); + + constexpr uint32_t PIPELINE_COPY_NUM = sizeof(allgatherBuf) / sizeof(allgatherBuf[0]); + TEventID ev_mte3_mte2[PIPELINE_COPY_NUM]; + TEventID ev_mte2_mte3[PIPELINE_COPY_NUM]; + LocalTensor local_tensors[PIPELINE_COPY_NUM]; + + for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) { + ev_mte3_mte2[i] = GetTPipePtr()->AllocEventID(); + ev_mte2_mte3[i] = GetTPipePtr()->AllocEventID(); + SetFlag(ev_mte3_mte2[i]); + local_tensors[i] = allgatherBuf[i].Get(); + } + + uint32_t offset = core_buf_offset; + uint32_t copy_len = max_ub_ping_pong_size; // num of MmadDtype, not the byte length + uint32_t copy_count = DivCeil(data_len, copy_len); + uint32_t pipe_id = 0; + + for (uint32_t i = 0; i < copy_count; i++) { + uint32_t actual_copy_len = + (i == copy_count - 1) ? 
(data_len - i * copy_len) : copy_len; + + auto &local_tensor = local_tensors[pipe_id]; + + WaitFlag(ev_mte3_mte2[pipe_id]); + DataCopy(local_tensor, src_global[offset], actual_copy_len); + SetFlag(ev_mte2_mte3[pipe_id]); + WaitFlag(ev_mte2_mte3[pipe_id]); + DataCopy(dst_global[offset], local_tensor, actual_copy_len); + SetFlag(ev_mte3_mte2[pipe_id]); + + offset += actual_copy_len; + pipe_id = (pipe_id + 1) % PIPELINE_COPY_NUM; + } + + for (uint32_t i = 0; i < PIPELINE_COPY_NUM; i++) { + WaitFlag(ev_mte3_mte2[i]); + GetTPipePtr()->ReleaseEventID(ev_mte3_mte2[i]); + GetTPipePtr()->ReleaseEventID(ev_mte2_mte3[i]); + } + + PipeBarrier(); + } + + __gm__ MmadDtype *gm_out; + __gm__ MmadDtype *gm_add_input; + __gm__ MmadDtype *gm_add_output; + __gm__ MmadDtype *gm_gamma; + __ubuf__ int32_t *ub_ctrl_flag; + + int32_t batch_size; + int32_t m; + int32_t k; + int32_t n; + int32_t m0; + int32_t k0; + int32_t n0; + + int32_t m_loop; + int32_t n_loop; + int32_t k_loop; + int32_t core_loop; + int32_t core_idx; + + int32_t rank; + int32_t rank_size; + int32_t tiling_key; + int32_t swizzl_count; + bool swizzl_direct; + + bool trans_a; + bool trans_b; + bool is_int8; + bool is_91093; + bool is_gather_add_out; + + int32_t aiv_idx; + int32_t other_rank; + int32_t core_num; + int32_t max_ub_single_dma_size; + int32_t max_ub_ping_pong_size; + + int32_t gm_c_pingpong_size; + int32_t withSerialMode; + int32_t tag; + int32_t comm_npu_split; + int32_t comm_data_split; + int32_t comm_direct; + + int32_t core_count; + bool is_deterministic; + + QuantGranularity dequant_granularity; + int32_t dequant_group_size; + QuantGranularity quant_granularity; + int32_t quant_group_size; + + WorkspaceInfo workspace_info; + int32_t ag_dim; + int32_t rs_dim; + bool inner_dim_is_Ag; + bool weight_nz{false}; + + float epsilon; + + TPipe pipe; + AscendC::TBufPool step1BufPool; + AscendC::TBufPool step2BufPool; + + AscendC::TQue inQueueX, inQueueY; + AscendC::TQue outQueueZ; + AscendC::TQue addOutQueue; + AscendC::TQue outQueue; + + AscendC::TBuf ctrlBuf; + AscendC::TBuf gammaBuf; + AscendC::TBuf xFp32Buf; + AscendC::TBuf sqxBuf; + AscendC::TBuf reduceFp32Buf; + AscendC::TBuf allgatherBuf[2]; + + Hccl hccl_; +}; +#endif \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_tiling.h b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_tiling.h new file mode 100644 index 00000000..88a4401e --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_tiling.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H +#define MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H + +#include +#include "kernel_tiling/kernel_tiling.h" + +enum QuantGranularity : int { + QUANT_GRANULARITY_UNDEFINED = -1, + PER_TENSOR = 0, + PER_CHANNEL = 1, + PER_GROUP = 2, + QUANT_GRANULARITY_MAX = 3, +}; + +struct Opshape { + int32_t batchSize = 1; + int32_t m = -1; + int32_t k = -1; + int32_t n = -1; +}; + +struct PPTilingData { + Opshape opShape = {}; + int32_t m0 = 1; + int32_t k0 = 1; + int32_t n0 = 1; + int32_t mLoop = 1; + int32_t kLoop = 1; + int32_t nLoop = 1; + int32_t coreLoop = 1; + int32_t swizzlCount = 1; + int32_t swizzlDirect = 0; + uint32_t tilingKey = 0; + int32_t blockDim = 1; + int32_t splitK = 0; + bool weightNz = false; + bool isTransA = false; + bool isTransB = false; + bool isGatherAddOut = false; +}; + +struct CommTilingData { + int32_t rank = 1; + int32_t rankSize = 1; + int32_t pValue = 1; + int32_t ubMoveNum = 1; + int32_t write2OtherRank = 0; + int32_t withSerialMode = 0; + int32_t tag = 0; + int32_t commNpuSplit = 1; + int32_t commDataSplit = 1; + int32_t commDirect = 0; + int32_t lenPerLoop = 1; + int32_t is91093 = 0; + int32_t buffer_size = 0; +}; + +struct RmsNormTilingData { + RmsNormTiling tiling{}; + uint32_t loopCount; + uint32_t calcBytes; + float epsilon{}; +}; + +struct QuantInfo { + QuantGranularity dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + int32_t dequantGroupSize = -1; + QuantGranularity quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + int32_t quantGroupSize = -1; +}; + +struct MatmulAllreduceAddRmsnormInfo { + PPTilingData ppTilingData{}; + CommTilingData commTilingData{}; + RmsNormTilingData rmsnormTilingData{}; + QuantInfo quantInfo{}; +}; + +struct MatmulAllreduceAddRmsnormTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + MatmulAllreduceAddRmsnormInfo matmulAllreduceAddRmsnormInfo; +}; + +#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_TILING_H \ No newline at end of file diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_utils.h b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_utils.h new file mode 100644 index 00000000..9f2e8897 --- /dev/null +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm_utils.h @@ -0,0 +1,414 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H +#define MATMUL_ALLREDUCE_ADD_RMSNORM_UTILS_H + +#include +#include "kernel_operator.h" +using namespace AscendC; + +constexpr int64_t ND2NZ_STRIDE_LIMIT = 65536; +constexpr int32_t AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID = 12; +constexpr int32_t MAX_BLOCK_COUNT = 2; +constexpr int32_t BLOCK_COUNT_3 = 3; +constexpr int32_t BLOCK_COUNT_4 = 4; +constexpr int32_t TILE_BLOCK_MOD = 2; + +constexpr int32_t BLOCK_SIZE_32B = 32; +constexpr int32_t BLOCK_SIZE_256B = 256; +constexpr int32_t BLOCK_SIZE_512B = 512; + +constexpr int32_t FFTS_SYNC_INTERNEL_MODE = 0; +constexpr int32_t FFTS_SYNC_AICORE_GROUP_MODE = 2; + +constexpr int32_t SWIZZL_MASK = 0b100000; +constexpr int32_t TRANS_A_MASK = 0b010000; +constexpr int32_t TRANS_B_MASK = 0b001000; +constexpr int32_t INT8_MASK = 0b000100; +constexpr int32_t BIAS_MASK = 0b000010; + +template +struct BaseBlock { + static_assert((SIZE & (SIZE - 1)) == 0, "Invalid block size"); + static constexpr size_t size = SIZE / sizeof(T); + + static __aicore__ inline size_t Count(size_t len) + { + return (len + size - 1) / size; + } + + static __aicore__ inline bool IsAligned(size_t len) + { + return len % size == 0; + } + + static __aicore__ inline size_t AlignUp(size_t len) + { + return (len + size - 1) & ~(size - 1); + } + + static __aicore__ inline size_t AlignDown(size_t len) + { + return len & ~(size - 1); + } +}; + +template +using Block32B = BaseBlock; + +template +using Block256B = BaseBlock; + +template +using Block512B = BaseBlock; + +struct WorkspaceInfo { + __gm__ uint8_t *gm_a_align{ nullptr }; + __gm__ uint8_t *gm_b_align{ nullptr }; + __gm__ uint8_t *gm_accum{ nullptr }; + __gm__ uint8_t *gm_dequant_param{ nullptr }; +}; + +template +__aicore__ inline LocalTensor CreateLocalTensor(__ubuf__ T *addr) +{ + LocalTensor tensor; + TBuffAddr taddr; + taddr.bufferAddr = reinterpret_cast(addr); + tensor.SetAddr(taddr); + return tensor; +} + +template +__aicore__ inline LocalTensor CreateLocalTensor(uint32_t buffer_offset) +{ + LocalTensor tensor; + tensor.address_.bufferAddr = buffer_offset; + return tensor; +} + +template +__aicore__ inline LocalTensor CreateLocalTensor(uint32_t buffer_offset, uint8_t logic_pos) +{ + LocalTensor tensor; + tensor.address_.logicPos = logic_pos; + tensor.address_.bufferAddr = buffer_offset; + return tensor; +} + +template +struct IntrinsicCopyGmToL1Nd2Nz { + static __aicore__ inline void move( + __cbuf__ T *dst, __gm__ T *src, + uint8_t sid, uint16_t ndNum, uint16_t nValue, uint16_t dValue, + uint16_t srcNdMatrixStride, uint16_t srcDValue, uint16_t dstNzC0Stride, + uint16_t dstNzNStride, uint16_t dstNzMatrixStride) { + Nd2NzParams nd2nzParams( + ndNum, nValue, dValue, + srcNdMatrixStride, srcDValue, dstNzC0Stride, + dstNzNStride, dstNzMatrixStride + ); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t dst_logicpos = static_cast(TPosition::C1); + LocalTensor dstTensor; + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + GlobalTensor srcTensor; + srcTensor.SetGlobalBuffer(src); + DataCopy(dstTensor, srcTensor, nd2nzParams); + } +}; + +template +struct CopyGmToL1Nd2zN { + static __aicore__ inline void move( + __cbuf__ T *dst, __gm__ T *src, + uint16_t nValue, uint16_t dValue, uint32_t srcDValue, uint16_t dstNzC0Stride) { + constexpr int BLOCK_LEN = 32 / sizeof(T); + if (srcDValue < ND2NZ_STRIDE_LIMIT) { + IntrinsicCopyGmToL1Nd2Nz::move( + dst, + src, + 0, + 1, + nValue, + dValue, + 0, + srcDValue, + dstNzC0Stride, + 1, + 0 + ); + } else { + for (int i = 0; i 
< nValue; i++) { + IntrinsicCopyGmToL1Nd2Nz::move( + dst + i * BLOCK_LEN, + src + i * srcDValue, + 0, + 1, + 1, + dValue, + 0, + 0, + dstNzC0Stride, + 0, + 0 + ); + } + } + } +}; + +__aicore__ inline void AlignJudge(bool trans_a, bool trans_b, int32_t m, int32_t k, int32_t n, int32_t m_align, + int32_t k_align, int32_t n_align, int32_t &aligned_a, int32_t &aligned_b) +{ + if (!trans_a) { + aligned_a = k != k_align; + } else { + aligned_a = (m != m_align && m != 1); + } + + if (!trans_b) { + aligned_b = (n != n_align); + } else { + aligned_b = (k != k_align); + } +} + +__aicore__ inline WorkspaceInfo GetWorkspaceInfo(__gm__ uint8_t *gm_workspace, int32_t batch_size, int32_t m, + int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool trans_a, bool trans_b, + int32_t mmad_dsize, bool has_a_align, bool has_b_align, bool has_accum = false, bool has_dequant_param = false) +{ + WorkspaceInfo workspace_info; + uint64_t workspace_offset = 0; + + if (has_a_align) { + workspace_info.gm_a_align = gm_workspace + workspace_offset; + workspace_offset += static_cast(batch_size) * (trans_a ? k * m_align : m * k_align) * mmad_dsize; + } + + if (has_b_align) { + workspace_info.gm_b_align = gm_workspace + workspace_offset; + workspace_offset += static_cast(batch_size) * (trans_b ? n * k_align : k * n_align) * mmad_dsize; + } + + if (has_accum) { + workspace_info.gm_accum = gm_workspace + workspace_offset; + workspace_offset += static_cast(batch_size) * m * n * sizeof(int32_t); + } + + if (has_dequant_param) { + workspace_info.gm_dequant_param = gm_workspace + workspace_offset; + workspace_offset += n * sizeof(float32_t); + } + + return workspace_info; +} + + +template +__aicore__ inline void CopyCubfToBt(uint64_t dst, __cbuf__ T *src, uint16_t convControl, uint16_t nBurst, + uint16_t lenBurst, uint16_t sourceGap, uint16_t dstGap) +{ + DataCopyParams intriParams(nBurst, lenBurst, sourceGap, dstGap); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); + uint8_t dst_logicpos = static_cast(TPosition::C2); + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + DataCopy(dstTensor, srcTensor, intriParams); +} + +template +__aicore__ inline void CopyGmToCbuf(__cbuf__ T *dst, __gm__ T *src, uint8_t sid, uint16_t nBurst, + uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride, pad_t padMode) +{ + DataCopyParams intriParams(nBurst, lenBurst, srcStride, dstStride); + GlobalTensor srcTensor; + srcTensor.SetGlobalBuffer(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t logicpos = static_cast(TPosition::C1); + LocalTensor dstTensor; + dstTensor = CreateLocalTensor(dst_buffer_offset, logicpos); + DataCopy(dstTensor, srcTensor, intriParams); +} + + +template +__aicore__ inline void SetFpc(__fbuf__ T *src) +{ + LocalTensor tensor; + uint32_t src_buffer_offset = reinterpret_cast(src); + tensor = CreateLocalTensor(src_buffer_offset); + SetFixPipeConfig(tensor); +} + + +template +__aicore__ inline void LoadCbufToCaTranspose(__ca__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat, + uint16_t srcStride, uint16_t dstStride, bool addrmode, + uint16_t dstFracStride) +{ + LoadData2dTransposeParams params( + indexID, + repeat, + srcStride, + dstStride, + dstFracStride, + addrmode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t 
dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); + uint8_t dst_logicpos = static_cast(TPosition::A2); + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + LoadDataWithTranspose(dstTensor, srcTensor, params); +} + +template +__aicore__ inline void LoadCbufToCbTranspose(__cb__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat, + uint16_t srcStride, uint16_t dstStride, bool addrmode, + uint16_t dstFracStride) +{ + LoadData2dTransposeParams params( + indexID, + repeat, + srcStride, + dstStride, + dstFracStride, + addrmode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); + uint8_t dst_logicpos = static_cast(TPosition::B2); + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + LoadDataWithTranspose(dstTensor, srcTensor, params); +} + +template +__aicore__ inline void LoadCbufToCa(__ca__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat, + uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose, + uint8_t addr_cal_mode) +{ + LoadData2dParams params( + baseIdx, + repeat, + srcStride, + sid, + dstStride, + transpose, + addr_cal_mode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); + uint8_t dst_logicpos = static_cast(TPosition::A2); + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + LoadData(dstTensor, srcTensor, params); +} + + +template +__aicore__ inline void LoadCbufToCb(__cb__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat, + uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose, + uint8_t addr_cal_mode) +{ + LoadData2dParams params( + baseIdx, + repeat, + srcStride, + sid, + dstStride, + transpose, + addr_cal_mode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); + uint8_t dst_logicpos = static_cast(TPosition::B2); + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + LoadData(dstTensor, srcTensor, params); +} + + +__aicore__ inline void GetBlockIdx(int32_t loop_idx, int32_t m_loop, int32_t n_loop, int32_t swizzl_direction, + int32_t swizzl_count, int64_t &m_idx, int64_t &n_idx) +{ + uint32_t in_batch_idx = loop_idx % (m_loop * n_loop); + if (swizzl_direction == 0) { + uint32_t tile_block_loop = (m_loop + swizzl_count - 1) / swizzl_count; + uint32_t tile_block_idx = in_batch_idx / (swizzl_count * n_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * n_loop); + + uint32_t n_row = swizzl_count; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzl_count * tile_block_idx; + } + m_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if (tile_block_idx % TILE_BLOCK_MOD != 0) { + n_idx = n_loop - n_idx - 1; + } + } else if (swizzl_direction == 1) { + 
uint32_t tile_block_loop = (n_loop + swizzl_count - 1) / swizzl_count; + uint32_t tile_block_idx = in_batch_idx / (swizzl_count * m_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * m_loop); + + uint32_t n_col = swizzl_count; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzl_count * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_col; + if (tile_block_idx % TILE_BLOCK_MOD != 0) { + m_idx = m_loop - m_idx - 1; + } + } +} + +template +__aicore__ inline void FFTSCrossCoreSync(uint64_t mode, uint64_t flag_id) +{ + uint64_t config = 1 | (mode << 4) | (flag_id << 8); + ffts_cross_core_sync(pipe, config); +} + + +template +__aicore__ GlobalTensor CreateGlobalTensor(__gm__ T *addr) +{ + GlobalTensor tensor; + tensor.SetGlobalBuffer(addr); + return tensor; +} + +#endif // MATMUL_ALLREDUCE_ADD_RMSNORM_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 706b711b..3709f91c 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -807,6 +807,36 @@ at::Tensor npu_sparse_flash_attention( output); return output; } +std::tuple matmul_allreduce_add_rmsnorm( + const at::Tensor &x1, + const at::Tensor &x2, + const at::Tensor &residual, + const at::Tensor &gamma, + c10::string_view group_tp, + int64_t tp_rank_size, + int64_t tp_rank_id, + double epsilon, + bool is_trans_b, + bool is_gather_add_out) + { + at::Tensor output = at::empty_like(residual); + at::Tensor add_out = at::empty_like(residual); + + std::string group_tp_str(group_tp); + + char *group_tp_ptr = group_tp_str.data(); + + float epsilon_f = static_cast(epsilon); + EXEC_NPU_CMD(aclnnMatmulAllreduceAddRmsnorm, + // input + x1, x2, residual, gamma, + // attr + group_tp_ptr, tp_rank_size, tp_rank_id, epsilon_f, is_trans_b, is_gather_add_out, + // output + output, add_out); + + return {output, add_out}; + } } // namespace vllm_ascend @@ -921,4 +951,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int max_output_size, Tensor! 
out) -> Tensor" ); ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine); + + ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \ + str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)"); + ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index aa776fc2..50b28e5d 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -264,6 +264,23 @@ at::Tensor npu_sparse_flash_attention_meta( at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype())); return output; } +std::tuple matmul_allreduce_add_rmsnorm_meta( + const at::Tensor &x1, + const at::Tensor &x2, + const at::Tensor &residual, + const at::Tensor &gamma, + c10::string_view group_tp, + int64_t tp_rank_size, + int64_t tp_rank_id, + double epsilon, + bool is_trans_b, + bool is_gather_add_out) + { + at::Tensor output = at::empty_like(residual); + at::Tensor add_out = at::empty_like(residual); + + return {output, add_out}; + } } // namespace meta } // namespace vllm_ascend @@ -296,5 +313,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta); // MoE dispatch-ffn-combine ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta); + // matmul allreduce add rmsnorm + ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta); } } diff --git a/tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py b/tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py new file mode 100644 index 00000000..762802a1 --- /dev/null +++ b/tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py @@ -0,0 +1,135 @@ +import gc +import os + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch_npu +import torchair + +from vllm_ascend.utils import enable_custom_op + +config = torchair.CompilerConfig() +config.mode = "reduce-overhead" +npu_backend = torchair.get_npu_backend(compiler_config=config) +torch_npu.npu.config.allow_internal_format = True +enable_custom_op() + +global_rank_id = 0 + + +def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon): + c_ret = torch.nn.functional.linear(a, b) + dist.all_reduce(c_ret) + rmsnorm_ret, _, add_ret = torch_npu.npu_add_rms_norm( + c_ret, residual, gamma, epsilon) + return rmsnorm_ret, add_ret + + +def worker(rank, ep_world_size, batch_size, m, k, n): + global global_rank_id + global_rank_id = rank + rank = rank + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group(backend="hccl", + rank=rank, + world_size=ep_world_size) + + ep_ranks_list = list(np.arange(0, ep_world_size)) + + ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list) + + torch_npu.npu.set_device(rank) + ep_hcomm_info = ep_group._get_backend( + torch.device("npu")).get_hccl_comm_name(rank) + + torch_npu.npu.synchronize(rank) + + class Module(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x1, x2, residual, gamma, ep_hcomm_info, epsilon, + is_trans_b, is_allgather_add_out): + out1, add_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm( + x1, x2, residual, gamma, ep_hcomm_info, 
ep_world_size,
+                global_rank_id, epsilon, is_trans_b, is_allgather_add_out)
+            return out1, add_out1
+
+    DTYPE = torch.bfloat16
+    USE_ONES = False
+
+    torch.manual_seed(42)
+
+    if USE_ONES:
+        x1 = torch.ones([m, k], dtype=DTYPE).npu(rank)
+        x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
+    else:
+        x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
+        x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)
+
+    if USE_ONES:
+        residual = torch.full([m, n], 2048, dtype=DTYPE).npu(rank)
+    else:
+        residual = torch.full([m, n], 0, dtype=DTYPE).npu(rank)
+
+    gamma = torch.full([n], 1, dtype=DTYPE).npu(rank)
+
+    epsilon = 1e-5
+    is_trans_b = True
+    is_allgather_add_out = True
+    warmup_cnt = 5
+    repeat_cnt = 10
+
+    def run_golden_case(loop_cnt):
+        for _ in range(loop_cnt):
+            golden_out, golden_add_out = golden_op_matmul_allreduce_add_rmsnorm(
+                x1, x2, residual, gamma, epsilon)
+        torch_npu.npu.synchronize(rank)
+        return golden_out, golden_add_out
+
+    run_golden_case(warmup_cnt)
+
+    golden_out, golden_add_out = run_golden_case(repeat_cnt)
+    golden_out = golden_out.detach().cpu()
+    golden_add_out = golden_add_out.detach().cpu()
+
+    mod = Module().npu()
+    opt_model = torch.compile(mod, backend=npu_backend)
+
+    def run_custom_case(loop_cnt):
+        for _ in range(loop_cnt):
+            out, add_out = opt_model(x1, x2, residual, gamma, ep_hcomm_info,
+                                     epsilon, is_trans_b, is_allgather_add_out)
+        torch_npu.npu.synchronize(rank)
+        return out, add_out
+
+    # warm up
+    run_custom_case(warmup_cnt)
+
+    out, add_out = run_custom_case(repeat_cnt)
+    out = out.detach().cpu()
+    add_out = add_out.detach().cpu()
+
+    dist.destroy_process_group()
+
+    torch.testing.assert_close(golden_out, out, atol=0.1, rtol=0.005)
+    torch.testing.assert_close(golden_add_out, add_out, atol=0.1, rtol=0.005)
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
+
+
+@torch.inference_mode()
+def test_matmul_allreduce_add_rmsnorm_kernel():
+    ep_world_size = 8
+    batch_size = 1
+    m = 10000
+    k = 1024
+    n = 5120
+    args = (ep_world_size, batch_size, m, k, n)
+    mp.spawn(worker, args=args, nprocs=ep_world_size, join=True)
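
For reference, a minimal eager-mode invocation of the new op (outside torch.compile) could look like the sketch below. It assumes enable_custom_op() has registered torch.ops._C_ascend, that an HCCL process group is initialized and the communicator name obtained as in the test above, and that the helper name, default epsilon, and tensor shapes are illustrative assumptions rather than anything mandated by the op.

import torch
import torch_npu  # noqa: F401  (registers the NPU device backend for torch)


def matmul_allreduce_add_rmsnorm_eager(x1, x2, residual, gamma, hcomm_name,
                                       tp_rank_size, tp_rank_id, epsilon=1e-5):
    # Sketch only. Per the registered schema: x1 [m, k] activations, x2 [n, k] weight
    # (isTransB=True), residual [m, n], gamma [n], all on the current NPU device;
    # hcomm_name is the TP group's HCCL communicator name, e.g. from
    # group._get_backend(torch.device("npu")).get_hccl_comm_name(rank).
    out, add_out = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
        x1, x2, residual, gamma,
        hcomm_name,      # groupTp: HCCL communicator name of the TP group
        tp_rank_size,    # tpRankSize
        tp_rank_id,      # tpRankId
        epsilon,         # epsilon used by the fused RMSNorm
        True,            # isTransB: x2 stored as [n, k]
        True)            # isGatherAddOut: also all-gather the residual-add output
    return out, add_out

Numerically, the result should track the golden path used in the test: a linear matmul followed by all_reduce, then torch_npu.npu_add_rms_norm on the reduced output and the residual.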