From 1ed9524763590e5ba509c0cd07c80113123e7bf6 Mon Sep 17 00:00:00 2001 From: guanguan0308 <162653673+guanguan0308@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:30:30 +0800 Subject: [PATCH] add dispath_ffn_combine_bf16 (#5866) ### What this PR does / why we need it? add dispath_ffn_combine_bf16 - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/bde38c11df0ea066a740efe9b77fff5418be45df --------- Signed-off-by: guanguan0308 <1546542263@qq.com> --- csrc/build_aclnn.sh | 8 + .../op_host/CMakeLists.txt | 66 ++ .../aclnn_dispatch_ffn_combine_bf16.cpp | 84 ++ .../op_host/aclnn_dispatch_ffn_combine_bf16.h | 39 + .../op_host/dispatch_ffn_combine_bf16_def.cpp | 88 ++ .../dispatch_ffn_combine_bf16_proto.cpp | 40 + .../dispatch_ffn_combine_bf16_tiling.cpp | 278 +++++ .../op_host/error_log.h | 47 + .../op_host/hcom_topo_info.h | 72 ++ .../op_host/tiling_args.h | 9 + .../op_kernel/dispatch_ffn_combine_bf16.cpp | 51 + .../op_kernel/dispatch_ffn_combine_bf16.h | 289 +++++ .../dispatch_ffn_combine_bf16_kernel.hpp | 1056 +++++++++++++++++ .../dispatch_ffn_combine_bf16_tiling.h | 56 + .../moe_init_routing_v2.cpp | 125 ++ .../moe_init_routing_v2_tiling.h | 557 +++++++++ .../moe_init_routing_v2/moe_v2_common.h | 201 ++++ .../moe_v2_expert_token_out.h | 380 ++++++ .../moe_init_routing_v2/moe_v2_gather_out.h | 198 ++++ .../moe_v2_init_routing_fullload.h | 388 ++++++ .../moe_init_routing_v2/moe_v2_mrgsort.h | 211 ++++ .../moe_init_routing_v2/moe_v2_mrgsort_out.h | 245 ++++ .../moe_init_routing_v2/moe_v2_sort_base.h | 74 ++ .../moe_v2_sort_multi_core.h | 507 ++++++++ .../moe_v2_sort_one_core.h | 226 ++++ .../moe_v2_src_to_dst_op.h | 173 +++ .../moe_v2_src_to_dst_op_simt.h | 96 ++ .../moe_v2_src_to_dst_with_capacity.h | 279 +++++ .../moe_init_routing_v2/tiling_base.h | 66 ++ .../op_kernel/unpermute/moe_token_unpermute.h | 376 ++++++ .../unpermute/moe_token_unpermute_tiling.h | 38 + .../utils/block_epilogue_pertoken_row.hpp | 208 ++++ .../utils/block_epilogue_pertoken_swiglu.hpp | 402 +++++++ .../utils/block_epilogue_pertoken_v2.hpp | 330 ++++++ ...block_mmad_preload_async_fixpipe_quant.hpp | 502 ++++++++ .../op_kernel/utils/const_args.hpp | 9 + .../op_kernel/utils/copy_gm_to_l1_custom.hpp | 40 + .../op_kernel/utils/copy_l0c_to_gm_custom.hpp | 47 + .../utils/dispatch_policy_custom.hpp | 53 + .../op_kernel/utils/get_tensor_addr.hpp | 16 + .../op_kernel/utils/hccl_shmem.hpp | 195 +++ .../op_kernel/utils/layout3d.hpp | 20 + .../op_kernel/utils/select_helper.hpp | 25 + csrc/torch_binding.cpp | 17 +- .../test_dispatch_ffn_combine_bf16.py | 234 ++++ 45 files changed, 8420 insertions(+), 1 deletion(-) create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/CMakeLists.txt create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_proto.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/error_log.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/hcom_topo_info.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_host/tiling_args.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_tiling.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2.cpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2_tiling.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_common.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_expert_token_out.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_gather_out.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_init_routing_fullload.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort_out.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_base.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_multi_core.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_one_core.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op_simt.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_with_capacity.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/tiling_base.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute_tiling.h create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_row.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_gm_to_l1_custom.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_l0c_to_gm_custom.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/dispatch_policy_custom.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/get_tensor_addr.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp create mode 100644 csrc/dispatch_ffn_combine_bf16/op_kernel/utils/select_helper.hpp create mode 100644 tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 7eba981b..8bfdda78 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -50,20 +50,28 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/" TARGET_FILE="$TARGET_DIR/$(basename "$HCCL_STRUCT_FILE_PATH")" + # for dispatch_ffn_combine_bf16 + SCRIPT_DIR_BF16=$(cd "$(dirname "$0")" && pwd) + TARGET_DIR_BF16="$SCRIPT_DIR_BF16/dispatch_ffn_combine_bf16/op_kernel/utils/" + TARGET_FILE_BF16="$TARGET_DIR_BF16/$(basename "$HCCL_STRUCT_FILE_PATH")" echo "*************************************" echo $HCCL_STRUCT_FILE_PATH echo "$TARGET_DIR" cp "$HCCL_STRUCT_FILE_PATH" "$TARGET_DIR" + cp "$HCCL_STRUCT_FILE_PATH" "$TARGET_DIR_BF16" sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE" sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE" + sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE_BF16" + sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE_BF16" CUSTOM_OPS_ARRAY=( "grouped_matmul_swiglu_quant_weight_nz_tensor_list" "lightning_indexer" "sparse_flash_attention" "dispatch_ffn_combine" + "dispatch_ffn_combine_bf16" "dispatch_gmm_combine_decode" "moe_combine_normal" "moe_dispatch_normal" diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/CMakeLists.txt b/csrc/dispatch_ffn_combine_bf16/op_host/CMakeLists.txt new file mode 100644 index 00000000..e73a00bd --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/CMakeLists.txt @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +set(_DISPATCH_FFN_INC_OPTS) +if (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include) +elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include) +elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include) +endif() +if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include) +endif() + +add_ops_compile_options( + OP_NAME DispatchFFNCombineBF16 + OPTIONS --cce-auto-sync=on + -Wno-deprecated-declarations + -Werror + -DHCCL_COMM + ${_DISPATCH_FFN_INC_OPTS} +) + +target_sources(op_host_aclnnInner PRIVATE + dispatch_ffn_combine_bf16_def.cpp +) + +target_sources(opapi PRIVATE + aclnn_dispatch_ffn_combine_bf16.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + aclnn_dispatch_ffn_combine_bf16.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + aclnn_dispatch_ffn_combine_bf16.cpp + ) +endif () + +target_sources(optiling PRIVATE + dispatch_ffn_combine_bf16_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel +) + +target_sources(opsproto PRIVATE + dispatch_ffn_combine_bf16_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_ffn_combine_bf16.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp new file mode 100644 index 00000000..e699e6e6 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "aclnn_dispatch_ffn_combine_bf16.h" +#include +// #include "aclnn_kernels/common/op_error_check.h" +// #include "opdev/op_log.h" +// #include "opdev/common_types.h" +// #include "opdev/platform.h" +// #include "ophost/matmul_util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../op_host/error_log.h" +// using namespace op; + +// using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +static constexpr size_t TWO_DIMS = 2; +static constexpr int64_t KVALUE_MIN = 256; +static constexpr int64_t KVALUE_MAX = 65535; +static constexpr size_t HCCL_GROUP_NAME_MAX = 128U; +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; + +extern aclnnStatus aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + bool transB, bool weightNz, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor); +extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize, + aclOpExecutor *executor, aclrtStream stream); +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + + + +aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor) +{ + bool transB = false; + bool weightNz = true; + + aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group, + maxOutputSize, transB, weightNz, + out, workspaceSize, executor); + return ret; +} + +aclnnStatus aclnnDispatchFFNCombineBF16(void* workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16(workspace, workspaceSize, executor, stream); + return ret; +} +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h new file mode 100644 index 00000000..a14f61fb --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h @@ -0,0 +1,39 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef OP_API_INC_DISPATCH_FFN_COMBINE_BF16_ +#define OP_API_INC_DISPATCH_FFN_COMBINE_BF16_ + +#include + +#include "aclnn/aclnn_base.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor); + + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif // OP_API_INC_DISPATCH_FFN_COMBINE_BF16_ \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp new file mode 100644 index 00000000..00bf2320 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_combine_bf16_def.cpp + * \brief + */ +#include "register/op_def_registry.h" + +namespace ops { +class DispatchFFNCombineBF16 : public OpDef { + public: + explicit DispatchFFNCombineBF16(const char *name) : OpDef(name) { + this->Input("a") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("w1") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .IgnoreContiguous(); + this->Input("w2") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .IgnoreContiguous(); + this->Input("expertIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("scale1") + .ParamType(DYNAMIC) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("scale2") + .ParamType(DYNAMIC) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("probs") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + // 输出 + this->Output("out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Attr("group").AttrType(REQUIRED).String(); + this->Attr("M").AttrType(OPTIONAL).Int(); + this->Attr("transB").AttrType(OPTIONAL).Bool(false); + this->Attr("weightNz").AttrType(OPTIONAL).Bool(false); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + this->AICore().AddConfig("ascend910_93", aicore_config); + // this->AICore().AddConfig("ascend910b", aicore_config); + this->MC2().HcclGroup("group"); + } +}; + +OP_ADD(DispatchFFNCombineBF16); +} // namespace ops \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_proto.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_proto.cpp new file mode 100644 index 00000000..ce4e7ab5 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_proto.cpp @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_proto.cpp + * \brief + */ +#include +#include +// #include "../../common/ophost/op_util.h" +// #include "../../common/ophost/hcom_topo_info.h" +// #include "log/ops_log.h" + +using namespace ge; +namespace ops { +const size_t ATTR_GROUP = 0; +const size_t ATTR_RANK_SIZE = 1; +const size_t SUPPORT_DIM_SIZE = 2; + +static ge::graphStatus InferShapeDispatchFFNCombineBF16(gert::InferShapeContext* context) { + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeDispatchFFNCombineBF16(gert::InferDataTypeContext* context) { + // auto d_type = context->GetInputDataType(0); + // context->SetOutputDataType(0, d_type); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(DispatchFFNCombineBF16) + .InferShape(InferShapeDispatchFFNCombineBF16) + .InferDataType(InferDataTypeDispatchFFNCombineBF16); +} // namespace ops diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp new file mode 100644 index 00000000..482470a2 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_tiling.cpp @@ -0,0 +1,278 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +/*! + * \file dispatch_ffn_tiling.cpp + * \brief + */ +#include "vector" +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "error_log.h" +#include "hcom_topo_info.h" +#include "register/op_def_registry.h" +#include "dispatch_ffn_combine_bf16_tiling.h" +#include +#include +#include +#include "moe_init_routing_v2/moe_init_routing_v2_tiling.h" + +using namespace AscendC; +using namespace ge; + +namespace { + const char *K_INNER_DEBUG = "DispatchFFNCombineBF16 Tiling Debug"; + constexpr uint32_t ATTR_GROUP_INDEX = 0; + constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1; + constexpr uint32_t ATTR_IS_TRANS_B = 2; + constexpr uint32_t ATTR_WEIGHT_NZ = 3; + constexpr uint64_t INIT_TILINGKEY = 1000000; + constexpr uint64_t TILINGKEY_TRANS_B = 1U; + constexpr uint64_t TILINGKEY_WEIGHT_NZ = 10; + constexpr uint32_t X_INDEX = 0; + constexpr uint32_t WEIGHT_INDEX = 1; + constexpr uint32_t WEIGHT2_INDEX = 2; + constexpr uint32_t EXPERTID_INDEX = 3; + constexpr uint32_t BLOCK_NUM = 20; + constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + +static int32_t CeilDev(int32_t num, int32_t div) +{ + if (div == 0) { + return 0; + } + return (num + div - 1) / div; +} + +static ge::graphStatus DispatchFFNCombineBF16CheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineBF16Info& info) +{ + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED); + + auto groupPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_INDEX)); + auto maxOutputSizePtr = attrs->GetAttrPointer(ATTR_MAX_OUTPUT_SIZE_INDEX); + auto is_trans_b = attrs->GetAttrPointer(ATTR_IS_TRANS_B); + auto weight_nz = attrs->GetAttrPointer(ATTR_WEIGHT_NZ); + OP_TILING_CHECK(groupPtr == nullptr || strlen(groupPtr) == 0, + OP_LOGE(K_INNER_DEBUG, "group is invalid."), return GRAPH_FAILED); + + OP_TILING_CHECK(is_trans_b == nullptr, + OP_LOGE(K_INNER_DEBUG, "is_trans_b is invalid."), return GRAPH_FAILED); + OP_TILING_CHECK(weight_nz == nullptr, + OP_LOGE(K_INNER_DEBUG, "weight_nz is invalid."), return GRAPH_FAILED); + + info.maxOutputSize = *maxOutputSizePtr; + info.isTransposeB = *is_trans_b; + info.isWeightNz = *weight_nz; + + int64_t rankSize; + (void)ge::HcomTopoInfo::Instance().GetGroupRankSize(groupPtr, rankSize); + info.worldSize = rankSize; + + OP_LOGD(K_INNER_DEBUG, "maxOutputSize=%d ", info.maxOutputSize); + OP_LOGD(K_INNER_DEBUG, "rankSize=%d ", info.worldSize); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchFFNCombineBF16CheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineBF16Info &info) +{ + const char *nodeName = context->GetNodeName(); + + const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX); + auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0); + uint32_t M = aStorageShape->GetStorageShape().GetDim(0); + uint32_t K = aStorageShape->GetStorageShape().GetDim(1); + + auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0); + uint32_t wTensorDims = wTensor->GetOriginShape().GetDimNum(); + uint32_t N = wTensor->GetStorageShape().GetDim(wTensorDims - 1); + + uint32_t topK = expertIdxTensor->GetStorageShape().GetDim(1); + uint32_t listLen = 0; + while (true) { + auto wTensorT = context->GetDynamicInputTensor(WEIGHT_INDEX, ++listLen); + if (wTensorT == nullptr) {break;} + } + + uint32_t expertPerRank; + if (listLen == 1) { + expertPerRank = wTensor->GetStorageShape().GetDim(0); + } else { + expertPerRank = listLen; + } + + info.M = M; + info.N = N; + info.K = K; + info.expertPerRank = expertPerRank; + info.topK = topK; + info.listLen = listLen; + OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M); + OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K); + OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N); + OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank); + OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK); + OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchFFNCombineBF16GetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineBF16Info& info) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0U; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + info.aivNum = aivNum; + info.totalUbSize = ubSize; + + OP_LOGD(K_INNER_DEBUG, "aivNum=%d", info.aivNum); + OP_LOGD(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize); + + return ge::GRAPH_SUCCESS; +} + +void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineBF16Info &info) +{ + cocTilingData.m0 = 128; + cocTilingData.k0 = 256; + cocTilingData.n0 = 256; + cocTilingData.swizzleDirect = 1; + cocTilingData.swizzleOffset = 7; + cocTilingData.ubMoveNum = 16 * 1024; + cocTilingData.pValue = 1; + cocTilingData.commNpuSplit = info.worldSize; + cocTilingData.commDataSplit = 1; + cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2; +} + +static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + OP_LOGI(nodeName, "Enter DispatchFFNCombineBF16 tiling func."); + + // 1. tilingData + DispatchFFNCombineBF16TilingData *tilingData = context->GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), + return ge::GRAPH_FAILED); + OP_LOGI(nodeName, "DispatchFFNCombineBF16 get tilingData."); + DispatchFFNCombineBF16Info& info = tilingData->dispatchFFNCombineBF16Info; + OP_LOGI(nodeName, "DispatchFFNCombineBF16 get tilingData info."); + + OP_TILING_CHECK(DispatchFFNCombineBF16CheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombineBF16 CheckAttrAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(DispatchFFNCombineBF16CheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombineBF16 CheckShapeAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(DispatchFFNCombineBF16GetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombineBF16 GetPlatformInfoAndSetTiling Failed"), + return ge::GRAPH_FAILED); + + SetTilingData(tilingData->cocTiling, info); + + // 2. set blockDim + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto aicNum = ascendcPlatform.GetCoreNumAic(); + auto aivNum = ascendcPlatform.GetCoreNumAiv(); + blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + context->SetBlockDim(blockDim); + + // 3. set tilingKey + uint64_t tilingKey = INIT_TILINGKEY; + tilingKey += info.isTransposeB ? TILINGKEY_TRANS_B : 0; + tilingKey += info.isWeightNz ? TILINGKEY_WEIGHT_NZ : 0; + context->SetTilingKey(tilingKey); + + OP_LOGD(K_INNER_DEBUG, "tilingKey=%d", tilingKey); + + optiling::MoeInitRoutingV2TilingBase moeInitRoutingQuantV2TilingBase; + int64_t inuptXDtypeSize = sizeof(int16_t); + int64_t scaleDim0 = 0; + int64_t ubSize = 196352; + int64_t expertCapacity = 0; + int64_t expertNum = info.expertPerRank * info.worldSize; + int64_t activeNum = info.M * info.topK; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 2; + bool expertTokensBeforeCapacityFlag = false; + int64_t quantMode = 1; + uint32_t aivNumInitRouting = 2 * BLOCK_NUM; + moeInitRoutingQuantV2TilingBase.DoTiling(info.M, info.K, info.topK, expertCapacity, expertNum, activeNum, dropPadMode, + expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0, aivNumInitRouting, ubSize); + uint64_t initRoutingQuantTilingKey = moeInitRoutingQuantV2TilingBase.tilingKey_; + size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_; + + tilingData->cocTiling.moeInitRoutingQuantV2TilingData = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.vbsComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vmsMiddleComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.vmsMiddleComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.sortOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.sortOutComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.srcToDstComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstCapacityComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.moeInitRoutingTilingData.gatherOutComputeParamsOp; + tilingData->cocTiling.initRoutingQuantTilingKey = initRoutingQuantTilingKey; + // OP_LOGE(initRoutingTilingKey, " initRoutingTilingKey."); + OP_LOGD(K_INNER_DEBUG, "tilingKey=%ld", initRoutingQuantTilingKey); + + // 4. workspace + size_t *workSpaces = context->GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), + return ge::GRAPH_FAILED); + + uint32_t n2 = info.K; + uint32_t k2 = info.N / 2; + + uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) + + info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 + + info.maxOutputSize * sizeof(float) * 2 + + info.maxOutputSize * info.N * sizeof(int16_t) + + info.maxOutputSize * n2 * sizeof(int16_t) + + info.maxOutputSize * info.K * sizeof(int16_t) + + info.maxOutputSize * k2 * sizeof(int16_t) + + info.worldSize * sizeof(int32_t) * 16; + // std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) + + // std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t)); + + workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace); + + + // 5. communication + auto attrs = context->GetAttrs(); + auto group = attrs->GetAttrPointer(static_cast(ATTR_GROUP_INDEX)); + uint32_t opType = 8U; + std::string algConfig = "AlltoAll=level0:fullmesh;level1:pairwise"; + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig); + mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling); + + OP_LOGI(nodeName, "Leave DispatchFFNCombineBF16 tiling func."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchFFNCombineBF16TilingFunc(gert::TilingContext* context) +{ + return DispatchFFNCombineBF16TilingFuncImpl(context); +} + +struct DispatchFFNCombineBF16CompileInfo {}; +ge::graphStatus TilingParseForDispatchFFNCombineBF16(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(DispatchFFNCombineBF16) + .Tiling(DispatchFFNCombineBF16TilingFunc) + .TilingParse(TilingParseForDispatchFFNCombineBF16); +} // namespace optiling \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/error_log.h b/csrc/dispatch_ffn_combine_bf16/op_host/error_log.h new file mode 100644 index 00000000..4ef02cd4 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/error_log.h @@ -0,0 +1,47 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/hcom_topo_info.h b/csrc/dispatch_ffn_combine_bf16/op_host/hcom_topo_info.h new file mode 100644 index 00000000..827d4c5b --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/hcom_topo_info.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ +#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ + +#include +#include + +using Status = int32_t; + +namespace ge { +static constexpr uint32_t COMM_MESH = 0b1U; +static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U); +static constexpr uint32_t COMM_RING = (COMM_MESH << 2U); +static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U); +class HcomTopoInfo { + public: + enum class TopoLevel { + L0 = 0, + L1, + MAX, + }; + struct TopoLevelDesc { + uint32_t comm_sets; + uint32_t rank_size; + }; + using TopoDescs = TopoLevelDesc[static_cast(TopoLevel::MAX)]; + struct TopoInfo { + int64_t rank_size; + void *notify_handle; + TopoDescs topo_level_descs; + }; + static HcomTopoInfo &Instance(); + bool TopoInfoHasBeenSet(const char_t *group); + bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info); + Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info); + Status GetGroupRankSize(const char_t *group, int64_t &rank_size); + TopoDescs *GetGroupTopoDesc(const char_t *group); + Status GetGroupNotifyHandle(const char_t *group, void *¬ify_handle); + void UnsetGroupTopoInfo(const char_t *group) { + const std::lock_guard lock(mutex_); + (void) rank_info_.erase(group); + } + + Status SetGroupOrderedStream(const char_t *group, void *stream); + Status GetGroupOrderedStream(const char_t *group, void *&stream); + void UnsetGroupOrderedStream(const char_t *group) { + const std::lock_guard lock(mutex_); + (void) group_to_ordered_stream_.erase(group); + }; + + Status SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream); + Status GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream); + void UnsetGroupOrderedStream(const int32_t device_id, const char_t *group); + private: + HcomTopoInfo() = default; + ~HcomTopoInfo() = default; + std::unordered_map rank_info_; + std::mutex mutex_; + std::unordered_map group_to_ordered_stream_; // Ordered stream for the communication domain + std::unordered_map> device_id_to_group_to_ordered_stream_; // Ordered stream for the communication domain +}; +} + +#endif // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/tiling_args.h b/csrc/dispatch_ffn_combine_bf16/op_host/tiling_args.h new file mode 100644 index 00000000..950cbe90 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_host/tiling_args.h @@ -0,0 +1,9 @@ +#ifndef TILING_ARGS_H +#define TILING_ARGS_H +#include + +namespace Moe { +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; +constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL; +} // namespace Moe +#endif // TILING_ARGS_H diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp new file mode 100644 index 00000000..370ffd82 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file dispatch_ffn_combine.cpp + * \brief + */ +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "dispatch_ffn_combine_bf16_tiling.h" +#include "dispatch_ffn_combine_bf16.h" + +using namespace AscendC; +using namespace DispatchFFNCombineBF16Impl; +extern "C" __global__ __aicore__ void dispatch_ffn_combine_bf16(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs, + GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM) +{ + REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData); + if (TILING_KEY_IS(1000000)) { + KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM); + DispatchFFNCombineBF16 op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000001)) { + KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM); + DispatchFFNCombineBF16 op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000010)) { + KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM); + DispatchFFNCombineBF16 op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000011)) { + KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM); + DispatchFFNCombineBF16 op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h new file mode 100644 index 00000000..12bd7949 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h @@ -0,0 +1,289 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_combine.h + * \brief + */ + +#ifndef DISPATCH_FFN_COMBINE_BF16_H +#define DISPATCH_FFN_COMBINE_BF16_H + +using namespace AscendC; + +#include "kernel_operator.h" + +#include "utils/moe_distribute_base.h" + +#include "dispatch_ffn_combine_bf16_tiling.h" + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/block/block_epilogue.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/epilogue/tile/tile_elemwise_add.hpp" +#include "catlass/epilogue/tile/tile_elemwise_muls.hpp" +#include "catlass/gemm/block/block_mmad.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/kernel/matmul_epilogue.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/layout/layout.hpp" +// #include "catlass/status.hpp" + +#include "utils/select_helper.hpp" +#include "utils/const_args.hpp" +#include "dispatch_ffn_combine_bf16_kernel.hpp" +#include "moe_init_routing_v2/moe_init_routing_v2_tiling.h" + +using namespace Catlass; +using namespace AscendC::HcclContextDef; + +namespace DispatchFFNCombineBF16Impl { +#define TemplateMMA2AClass typename AType_, typename BType_, typename CType_, bool TB_, bool Nz_ +#define TemplateMMA2ACFunc AType_, BType_, CType_, TB_, Nz_ + +using namespace AscendC; +template +class DispatchFFNCombineBF16 { +public: + __aicore__ inline DispatchFFNCombineBF16() {}; + __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM, + GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM); + __aicore__ inline void Process(); + + +private: + GM_ADDR xGM_; + GM_ADDR weight1GM_; + GM_ADDR weight2GM_; + GM_ADDR expertIdGM_; + GM_ADDR scale1GM_; + GM_ADDR scale2GM_; + GM_ADDR probs_; + GM_ADDR outGM_; + GM_ADDR workspaceGM_; + + GM_ADDR moeInitRoutingQuantV2Scale = nullptr; + GM_ADDR moeInitRoutingQuantV2Offset = nullptr; + GM_ADDR expertTokensBeforeCapacity = nullptr; + + + TBuf uBuf_; + + int32_t rank; + int32_t rankSize; + int32_t aivNum; + + int32_t m0; + int32_t k0; + int32_t n0; + int32_t swizzlOffset; + int32_t swizzlDirect; + int32_t ubMoveNum; + int32_t pValue; + + int32_t commNpuSplit; + int32_t commDataSplit; + int32_t lenPerLoop; + + int32_t m; + int32_t k; + int32_t n; + int32_t topK; + int32_t expertPerRank; + int32_t maxOutputSize; + int32_t EP; + int32_t listLen; + + optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData; + uint64_t initRoutingQuantTilingKey; + + // Hccl hccl_; + +}; + + +template +__aicore__ inline void DispatchFFNCombineBF16::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM, + GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM) +{ + REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData); + auto tiling = (__gm__ DispatchFFNCombineBF16TilingData*)tilingGM; + GET_TILING_DATA(tilingData, tilingGM); + + xGM_ = xGM; + weight1GM_ = weight1GM; + weight2GM_ = weight2GM; + expertIdGM_ = expertIdGM; + scale1GM_ = scale1GM; + scale2GM_ = scale2GM; + probs_ = probs; + + outGM_ = outGM; + + workspaceGM_ = workspaceGM; + + aivNum = tilingData.dispatchFFNCombineBF16Info.aivNum; + + m = tilingData.dispatchFFNCombineBF16Info.M; + k = tilingData.dispatchFFNCombineBF16Info.K; + n = tilingData.dispatchFFNCombineBF16Info.N; + EP = tilingData.dispatchFFNCombineBF16Info.worldSize; + topK = tilingData.dispatchFFNCombineBF16Info.topK; + expertPerRank = tilingData.dispatchFFNCombineBF16Info.expertPerRank; + maxOutputSize = tilingData.dispatchFFNCombineBF16Info.maxOutputSize; + listLen = tilingData.dispatchFFNCombineBF16Info.listLen; + + m0 = tilingData.cocTiling.m0; + k0 = tilingData.cocTiling.k0; + n0 = tilingData.cocTiling.n0; + swizzlDirect = tilingData.cocTiling.swizzleDirect; + swizzlOffset = tilingData.cocTiling.swizzleOffset; + ubMoveNum = tilingData.cocTiling.ubMoveNum; + pValue = tilingData.cocTiling.pValue; + commNpuSplit = tilingData.cocTiling.commNpuSplit; + commDataSplit = tilingData.cocTiling.commDataSplit; + lenPerLoop = tilingData.cocTiling.lenPerLoop; + moeInitRoutingQuantV2TilingData = tilingData.cocTiling.moeInitRoutingQuantV2TilingData; + initRoutingQuantTilingKey = tilingData.cocTiling.initRoutingQuantTilingKey; + + auto contextGM0 = AscendC::GetHcclContext(); + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; + + rank = WinContext_->localUsrRankId; + rankSize = WinContext_->rankSize; +} + +template +__aicore__ inline void DispatchFFNCombineBF16::Process() +{ + // Define ArchTag + using ArchTag = Arch::AtlasA2; + constexpr bool enableUnitFlag = false; + constexpr bool enableShuffleK = true; + + uint32_t k2 = n/2; + uint32_t n2 = k; + + int64_t activeNum = 0; + int64_t expertCapacity = 0; + int64_t expertNum = expertPerRank * EP; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 2; + bool expertTokensBeforeCapacityFlag = false; + int64_t quantMode = 1; + + using LayoutA = layout::RowMajor; + using LayoutB = typename std::conditional< + Nz_, + layout::zN, + typename std::conditional::type + >::type; + + LayoutB layoutB1 = LayoutBInitializer::create(k, n); + LayoutB layoutB2 = LayoutBInitializer::create(k2, n2); + using LayoutC = layout::RowMajor; + using L1TileShape = typename std::conditional< + std::is_same_v, + GemmShape<128, 256, 512>, + GemmShape<128, 256, 256> + >::type; + + constexpr uint32_t workspaceStages = 2; + constexpr uint32_t preloadStages = 1; + constexpr uint32_t l1Stages = 2; + constexpr uint32_t l0AStages = 2; + constexpr uint32_t l0BStages = 2; + constexpr uint32_t l0CStages = 1; + + using DispatchPolicy = Gemm::MmadAtlasA2PreloadAsyncFixpipe< + preloadStages, + l1Stages, l0AStages, l0BStages, l0CStages, + enableUnitFlag, enableShuffleK + >; + + using L0TileShape = typename std::conditional< + std::is_same_v, + GemmShape<128, 256, 128>, + GemmShape<128, 256, 64> + >::type; + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + using D1Type = Gemm::GemmType; + + using D2Type = typename std::conditional< + std::is_same_v, + Gemm::GemmType, + Gemm::GemmType + >::type; + + using BlockMmad = Gemm::Block::BlockMmad; + constexpr uint32_t ubStages = 2; + + using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant; + + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using ElementMulType = Gemm::GemmType; + using TileElemWiseMuls = Epilogue::Tile::TileElemWiseMuls; + + using TileCopy1 = Epilogue::Tile::TileCopy; + using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; + + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2; + using TileCopy2 = Epilogue::Tile::TileCopy; + using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<9, 1>; + using ElementGroupList = int64_t; + using MatmulKernel = Gemm::Kernel::DispatchFFNCombineBF16Kernel; + + LayoutA layoutA1{static_cast(m), static_cast(k)}; + LayoutA layoutA2{static_cast(m), static_cast(k2)}; + layout::VectorLayout layoutScale1{static_cast(n)}; + layout::VectorLayout layoutScale2{static_cast(n2)}; + layout::RowMajor layoutD1{static_cast(maxOutputSize), static_cast(k2)}; + layout::RowMajor layoutD2{static_cast(m*topK), static_cast(n2)}; + // Prepare params + + GemmCoord problemShape{static_cast(m), static_cast(n), static_cast(k)}; + + uint32_t epilogueCoreNum = aivNum; + uint32_t epilogueGranularity = expertPerRank - 2; + + typename MatmulKernel::Params params{ + problemShape, static_cast(EP), static_cast(listLen), static_cast(expertPerRank), static_cast(maxOutputSize), + static_cast(rank), static_cast(rankSize), + static_cast(topK), initRoutingQuantTilingKey, + epilogueCoreNum, epilogueGranularity, + xGM_, layoutA1, layoutA2, + weight1GM_, layoutB1, + weight2GM_, layoutB2, + scale1GM_, layoutScale1, + scale2GM_, layoutScale2, + outGM_, layoutD1, layoutD2, + expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset, + expertTokensBeforeCapacity, probs_, + workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData}; + //Call kernel + MatmulKernel kernel(params); + kernel(params); +} + +} // DispatchFFNCombineBF16Impl +#endif // DISPATCH_FFN_COMBINE_H + diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp new file mode 100644 index 00000000..b1c74aa8 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp @@ -0,0 +1,1056 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DISPATCH_FFN_COMBINE_BF16_KERNEL_HPP +#define DISPATCH_FFN_COMBINE_BF16_KERNEL_HPP + +#include "kernel_operator.h" + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +#include "utils/block_mmad_preload_async_fixpipe_quant.hpp" +#include "utils/copy_gm_to_l1_custom.hpp" +#include "utils/copy_l0c_to_gm_custom.hpp" +#include "utils/block_epilogue_pertoken_v2.hpp" +#include "utils/block_epilogue_pertoken_swiglu.hpp" +#include "utils/hccl_shmem.hpp" +#include "utils/const_args.hpp" +#include "utils/layout3d.hpp" +#include "utils/get_tensor_addr.hpp" + +#include "moe_init_routing_v2/moe_init_routing_v2_tiling.h" +#include "moe_init_routing_v2/moe_init_routing_v2.cpp" +#include "moe_init_routing_v2/moe_v2_init_routing_fullload.h" +#include "unpermute/moe_token_unpermute.h" + + +using namespace AscendC; + +namespace Catlass::Gemm::Kernel { + +template < + class BlockMmad_, + class BlockScheduler_, + class ElementGroupList_, + class BlockEpilogue1_, + class BlockEpilogue2_ +> +class DispatchFFNCombineBF16Kernel { +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + using ElementScale = uint64_t; + // using ElementScale = int32_t; + using LayoutScale = typename layout::VectorLayout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = typename layout::VectorLayout; + using BlockScheduler = BlockScheduler_; + using BlockEpilogue1 = BlockEpilogue1_; + using BlockEpilogue2 = BlockEpilogue2_; + using ElementD1 = typename BlockEpilogue1::ElementD; + using LayoutD1 = typename BlockEpilogue1::LayoutD; + using ElementD2 = typename BlockEpilogue2::ElementD; + using LayoutD2 = typename BlockEpilogue2::LayoutD; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + __gm__ ElementA *ptrA; + LayoutA layoutA; + LayoutA layoutA2; + GM_ADDR ptrB1; + LayoutB layoutB1; + GM_ADDR ptrB2; + LayoutB layoutB2; + GM_ADDR ptrScale1; + LayoutScale layoutScale1; + GM_ADDR ptrScale2; + LayoutScale layoutScale2; + __gm__ ElementD2 *ptrOutput; + LayoutD1 layoutD1; + LayoutD2 layoutD2; + GM_ADDR ptrWorkspace; + int32_t EP; + int32_t listLen; + int32_t expertPerRank; + uint32_t maxOutputSize; + uint32_t rank; + uint32_t rankSize; + int32_t ubMoveNum; + GM_ADDR symmetricPtr; + //-------------- + GM_ADDR expertIdx; + GM_ADDR moeInitRoutingQuantV2Scale; + GM_ADDR moeInitRoutingQuantV2Offset; + GM_ADDR expandedX; + GM_ADDR expandedRowIdx; + GM_ADDR expertTokensCountOrCumsum; + GM_ADDR expertTokensBeforeCapacity; + GM_ADDR dynamicQuantScale; + GM_ADDR probs; + int64_t topK; + uint64_t initRoutingQuantTilingKey; + uint32_t epilogueCoreNum; + uint32_t epilogueGranularity; + optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData; + //-------------- + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params( + GemmCoord problemShape_, + uint32_t EP_, uint32_t listLen_, uint32_t expertPerRank_, uint32_t maxOutputSize_, + uint32_t rank_, uint32_t rankSize_, int64_t topK_, + uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_, + GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_, + GM_ADDR ptrB1_, LayoutB layoutB1_, + GM_ADDR ptrB2_, LayoutB layoutB2_, + GM_ADDR ptrScale1_, LayoutScale layoutScale1_, + GM_ADDR ptrScale2_, LayoutScale layoutScale2_, + GM_ADDR ptrOutput_, LayoutD2 layoutD1_, LayoutD2 layoutD2_, + GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_, + GM_ADDR moeInitRoutingQuantV2Offset_, + GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_, + GM_ADDR ptrWorkspace_, int32_t ubMoveNum_, + optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData_, + GM_ADDR symmetricPtr_ = nullptr + ) : problemShape(problemShape_), + EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_), + rank(rank_), rankSize(rankSize_), topK(topK_), + initRoutingQuantTilingKey(initRoutingQuantTilingKey_), + epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_), + ptrB1(ptrB1_), layoutB1(layoutB1_), + ptrB2(ptrB2_), layoutB2(layoutB2_), + ptrScale1(ptrScale1_), layoutScale1(layoutScale1_), + ptrScale2(ptrScale2_), layoutScale2(layoutScale2_), + ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_), + expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_), + moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), + expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_), + ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_), + moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_) + { + moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingData_.vbsComputeParamsOp; + moeInitRoutingQuantV2TilingData.vmsMiddleComputeParamsOp = moeInitRoutingQuantV2TilingData_.vmsMiddleComputeParamsOp; + moeInitRoutingQuantV2TilingData.sortOutComputeParamsOp = moeInitRoutingQuantV2TilingData_.sortOutComputeParamsOp; + moeInitRoutingQuantV2TilingData.srcToDstComputeParamsOp = moeInitRoutingQuantV2TilingData_.srcToDstComputeParamsOp; + moeInitRoutingQuantV2TilingData.srcToDstCapacityComputeParamsOp = moeInitRoutingQuantV2TilingData_.srcToDstCapacityComputeParamsOp; + moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = moeInitRoutingQuantV2TilingData_.gatherOutComputeParamsOp; + } + }; + + // Methods + CATLASS_DEVICE + DispatchFFNCombineBF16Kernel(Params const ¶ms) + { + if ASCEND_IS_AIC { + coreIdx = AscendC::GetBlockIdx(); + coreNum = AscendC::GetBlockNum(); + } + + if ASCEND_IS_AIV { + coreIdx = get_block_idx() + get_subblockid() * get_block_num(); + coreNum = get_block_num() * get_subblockdim(); + } + + initBuffer(params); + } + + CATLASS_DEVICE + ~DispatchFFNCombineBF16Kernel() + { + } + + template + CATLASS_DEVICE + void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE + void operator()(Params const ¶ms) + { + GMM1(params); + + AscendC::CrossCoreWaitFlag<0x2>(2); + + GMM2(params); + } + + + template <> + CATLASS_DEVICE + void operator()(Params const ¶ms) + { + Dispatch(params); + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + + CombineV2(params); + } + +private: + CATLASS_DEVICE void initBuffer(Params const ¶ms) { + #ifndef HCCL_COMM + shmem.initShmem(params.symmetricPtr, params.rank, params.rankSize); + #endif + workspaceInfo = WorkspaceInfo(params); + peermemInfo = PeermemInfo(params, shmem); + + cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); + + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); + + gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); + gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); + + gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); + gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2)); + + tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); + + tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank); + preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank)); + } + + template + CATLASS_DEVICE void CopyGMToGM( + AscendC::GlobalTensor dst, + AscendC::GlobalTensor src, + int32_t elemNum, + int32_t ubMoveNum + ) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + + using TType = Gemm::GemmType; + using CopyGmToUb = Epilogue::Tile::CopyGm2Ub; + using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; + CopyGmToUb copyGmToUb; + CopyUbToGm copyUbToGm; + constexpr int32_t BufferNum = 2; + int tmpBufferSize = 32 * 1024 / sizeof(T); // 32 KB + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + tmpBuffer1.SetSize(tmpBufferSize); + int tmpBufferOffset = 96 * 1024; // half of UB + AscendC::LocalTensor tmpBuffer2 = resource.ubBuf.template GetBufferByByte(tmpBufferOffset); + tmpBuffer2.SetSize(tmpBufferSize); + + // [ReduceScatter] 2. Pre Interface Sync + int pingpongId = 0; + auto processCount = CeilDiv(elemNum, ubMoveNum); + for (uint32_t processIndex = 0; processIndex < processCount; ++processIndex) { + uint32_t curProcessNum = (processIndex == processCount - 1) ? elemNum - ubMoveNum * (processCount - 1) : ubMoveNum; + AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2; + auto processOffset = processIndex * ubMoveNum; + + auto inputOffset = processOffset; + auto outputOffset = processOffset; + // [ReduceScatter] 2. Pre Interface Sync + AscendC::WaitFlag(EVENT_ID); + // [ReduceScatter] 3. Start shmem_mte_get_mem_nbi + copyGmToUb(buf, src[inputOffset], layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum}); + AscendC::SetFlag(EVENT_ID); + AscendC::WaitFlag(EVENT_ID); + copyUbToGm(dst[outputOffset], buf, layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum}); + + // [ReduceScatter] 4. Post Interface Sync + AscendC::SetFlag(EVENT_ID); + pingpongId = (pingpongId + 1) % BufferNum; + } + // [ReduceScatter] 4. Post Interface Sync + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + + CATLASS_DEVICE + void GetCumsumForMMAIV(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP) + { + int32_t expertPerRankAligned = (expertPerRank + 8 - 1) / 8 * 8; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + AscendC::LocalTensor tmpResult = resource.ubBuf.template GetBufferByByte(EP * expertPerRank * sizeof(int32_t)); + #define U16(x) static_cast(x) + + AscendC::DataCopyPad( + tmpBuffer1, + tokenPerExpert[rankId * expertPerRank], + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0}, + {} + ); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + for (uint32_t i = 1; i < EP; ++i) { + AscendC::Add(tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[(i - 1) * expertPerRankAligned], expertPerRank); + AscendC::PipeBarrier(); + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::DataCopyPad( + result, + tmpBuffer1, + {U16(EP), U16((expertPerRank) * sizeof(int32_t)), 0, 0} + ); + } + +CATLASS_DEVICE + void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRank(Params const ¶ms, int64_t localTokenPerExpertOffset){ + AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); + AscendC::LocalTensor ubFloat = resource.ubBuf.template GetBufferByByte(0); + uint32_t numPerCore = params.EP * params.expertPerRank; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx == params.rank) { + continue; + } + AscendC::GlobalTensor srcAddress; + srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset)); + AscendC::GlobalTensor dstAddress; + __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx); + dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr); + + AscendC::SetFlag(EVENT_ID0); + using TType = Gemm::GemmType; + using CopyGmToUb = Epilogue::Tile::CopyGm2Ub; + using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; + CopyGmToUb copyGmToUb; + CopyUbToGm copyUbToGm; + + AscendC::WaitFlag(EVENT_ID0); + + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx != params.rank) { + int32_t intPer512 = CACHE_LINE / sizeof(int); + for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) { + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); + gm_signal_wait_until_ne(sync_check, 0); + } + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + } else { + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + } + + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const ¶ms, int64_t localTokenPerExpertOffset){ + AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); + AscendC::LocalTensor ubFloat = resource.ubBuf.template GetBufferByByte(0); + uint32_t numPerCore = params.EP * params.expertPerRank; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx == params.rank) { + continue; + } + AscendC::GlobalTensor srcAddress; + srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset)); + AscendC::GlobalTensor dstAddress; + __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx); + dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr); + + AscendC::SetFlag(EVENT_ID0); + using TType = Gemm::GemmType; + using CopyGmToUb = Epilogue::Tile::CopyGm2Ub; + using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; + CopyGmToUb copyGmToUb; + CopyUbToGm copyUbToGm; + + AscendC::WaitFlag(EVENT_ID0); + + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx != params.rank) { + int32_t intPer512 = CACHE_LINE / sizeof(int); + for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) { + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); + gm_signal_wait_until_ne(sync_check, 0); + } + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + } else { + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + int32_t prevSum = 0; + for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) { + prevSum += tmpBuffer(i); + } + preSumBeforeRank(dstEpIdx * 16) = prevSum; + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(preSumBeforeRank[dstEpIdx * 16]); + __asm__ __volatile__(""); + + } + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void GetSumPreRank(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, + uint32_t expertPerRank, uint32_t rankId, uint32_t EP) { + int32_t cursum = 0; + if (coreIdx < EP) { + for (int32_t i = 0; i < rankId * expertPerRank; i++) { + cursum += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i)); + } + result.SetValue(coreIdx * 16, cursum); + + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(result[coreIdx * 16]); + __asm__ __volatile__(""); + } + } + + CATLASS_DEVICE + void ResetTokenPerExpert(AscendC::GlobalTensor & tokenPerExpert, int32_t num) + { + if (coreIdx != coreNum - 1) { + return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmp, 0, num); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert, tmp, num); + } + + CATLASS_DEVICE + void GMM1(Params const ¶ms){ + icache_preload(8); + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + uint32_t startCoreIdx = 0; + uint32_t syncGroupIdx = 0; + AscendC::CrossCoreWaitFlag<0x2>(0); + AicSyncAll(); + int64_t preCurrentmSum = 0; + int32_t syncLoopIdx = -1; + + __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK]; + __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK]; + + int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank; + for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { + weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr(loopIdx, params.ptrB1)); + scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale1)); + } + AscendC::PipeBarrier(); + + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preCurrentmSum >= params.maxOutputSize) { + currentM = 0; + } else if (preCurrentmSum + currentM >= params.maxOutputSize) { + currentM = params.maxOutputSize - preCurrentmSum; + } + AscendC::GlobalTensor gmB1; + AscendC::GlobalTensor gmS; + int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; + gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx])); + gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx])); + + AscendC::PipeBarrier(); + + if (currentM <= L1TileShape::M) { + gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB1 = params.layoutB1; + LayoutScale layoutScale = params.layoutScale1; + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + for(;syncGroupIdx <= groupIdx; syncGroupIdx++) { + AscendC::CrossCoreWaitFlag<0x2>(0); + } + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB1.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + if (currentM > 0) { + if constexpr (std::is_same_v) { + int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // 每个expert一组scale + blockMmad( + gmA[gmGroupOffsetA + gmOffsetA], layoutA, + gmB1[gmGroupOffsetB + gmOffsetB], layoutB1, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, + gmS[gmOffsetS], layoutScale, + actualBlockShape + ); + } else { + blockMmad( + gmA[gmGroupOffsetA + gmOffsetA], layoutA, + gmB1[gmGroupOffsetB + gmOffsetB], layoutB1, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, + gmS, layoutScale, + actualBlockShape + ); + } + } + } + + if ((groupIdx + 1) == params.epilogueGranularity && (groupIdx < params.expertPerRank - 1)) { + syncLoopIdx ++; + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + blockMmad.Finalize(syncLoopIdx, 1); + } + + preCurrentmSum += currentM; + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if (params.listLen == 1) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + blockMmad.Finalize(syncLoopIdx + 1, 1); + } + + CATLASS_DEVICE + void GMM2(Params const ¶ms) { + icache_preload(8); + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + uint32_t n2 = params.problemShape.k(); + uint32_t k2 = params.problemShape.n() / 2; + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + + uint32_t startCoreIdx = 0; + + AscendC::PipeBarrier(); + + int64_t preCurrentmSum = 0; + int32_t syncLoopIdx = -1; + uint32_t lastDequantExpertNum = params.expertPerRank; + + if (params.epilogueGranularity < params.expertPerRank) { + lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; + } + + __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK]; + __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK]; + int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank; + for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { + weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(loopIdx, params.ptrB2)); + scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale2)); + } + AscendC::PipeBarrier(); + + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preCurrentmSum >= params.maxOutputSize) { + currentM = 0; + } else if (preCurrentmSum + currentM > params.maxOutputSize) { + currentM = params.maxOutputSize - preCurrentmSum; + } + AscendC::GlobalTensor gmB2; + AscendC::GlobalTensor gmS2; + AscendC::PipeBarrier(); + int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; + gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx])); + gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx])); + + if (currentM <= L1TileShape::M) { + gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + GemmCoord inGroupProblemShape{currentM, n2, k2}; // M N K + + LayoutA layoutA = params.layoutA2.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB2 = params.layoutB2; + LayoutScale layoutScale = params.layoutScale2; + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { + AscendC::CrossCoreWaitFlag<0x2>(2); + } + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + if (loopIdx + coreNum >= coreLoops) { + syncLoopIdx = groupIdx; + } + + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB2.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + if (currentM > 0) { + if constexpr (std::is_same_v) { + int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // 每个expert一组scale + blockMmad( + gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, + gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, + gmC2[gmGroupOffsetC + gmOffsetC], layoutC, + gmS2[gmOffsetS], layoutScale, + actualBlockShape, syncLoopIdx, 3 + ); + } else { + blockMmad( + gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, + gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, + gmC2[gmGroupOffsetC + gmOffsetC], layoutC, + gmS2, layoutScale, + actualBlockShape, syncLoopIdx, 3 + ); + } + } + } + preCurrentmSum += currentM; + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if (params.listLen == 1) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + } + + CATLASS_DEVICE + void Dispatch(Params const ¶ms) { + icache_preload(8); + int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t); + GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; + uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t); + + moe_init_routing_v2(reinterpret_cast (params.ptrA), params.expertIdx, shmem() + peermemInfo.offsetA, + workspaceInfo.expandedRowIdx, localTokenPerExpert, params.expertTokensBeforeCapacity, + params.ptrWorkspace + expandedRowIdxOffset, + ¶ms.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey); + + AscendC::SyncAll(); + CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset); + + if (coreIdx == coreNum - 1) { + GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + + uint32_t curGroupOffset = 0; + int32_t prevSumBeforeRank = 0; + int32_t groupIdxDeq = 0; + int prevSum = 0; + if (coreIdx < params.EP) { + prevSum = preSumBeforeRank(coreIdx * 16); + } + uint32_t prevGroupSum1 = 0; + uint32_t dequantSum = 0; + int32_t syncLoopIdx = -1; + BlockEpilogue1 blockEpilogue(resource); + for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; + if (rowStart < params.maxOutputSize) { + uint32_t rows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + if (rowStart + rows > params.maxOutputSize) { + rows = params.maxOutputSize - rowStart; + } + uint32_t rowSrc = prevSum; + prevSum += rows; + GM_ADDR otherRankPtr = shmem(0, dstEpIdx); + AscendC::GlobalTensor gmRemoteA; + gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); + AscendC::GlobalTensor gmRemotePerTokenScale; + gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale)); + MatrixCoord offsetA{rowStart, 0}; + MatrixCoord shapeA{rows, params.problemShape.k()}; + MatrixCoord offsetPeer{rowSrc, 0}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); + CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); + if constexpr (std::is_same_v) { + CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows); + } + } + } + + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { + syncLoopIdx++; + AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { + uint32_t rowStartThisCore = 0; + MatrixCoord offsetC{0U, 0}; + uint32_t dequantLen = prevGroupSum1 - dequantSum; + if (dequantLen >= params.maxOutputSize) { + dequantLen = dequantLen - params.maxOutputSize; + } + + MatrixCoord shapeC{dequantLen, params.problemShape.n()}; + LayoutC layoutC{dequantLen, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + if constexpr (std::is_same_v) { + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + } else { + blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], params.epilogueCoreNum); + } + } + prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) { + dequantSum = 0; + } + } + syncLoopIdx ++; + + AscendC::SyncAll(); + + uint32_t lastDequantExpertNum = params.expertPerRank; + if (params.epilogueGranularity < params.expertPerRank) { + lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; + } + if (lastDequantExpertNum < params.expertPerRank) { + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + } + AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1); + AscendC::SyncAll(); + if (prevGroupSum1 - dequantSum < params.maxOutputSize) { + uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; + MatrixCoord offsetC{rowStartThisCore, 0}; + uint32_t dequantLen = dequantSum; + if (prevGroupSum1 >= params.maxOutputSize) { + dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize); + } + MatrixCoord shapeC{dequantLen, params.problemShape.n()}; + LayoutC layoutC{dequantLen, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + if constexpr (std::is_same_v) { + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + } else { + blockEpilogue(gmC[gmOffsetC], shapeC, gmPermutedToken[gmOffsetD], coreNum); + } + } + + blockEpilogue.Finalize(); + } + + CATLASS_DEVICE + void CombineV2(Params const ¶ms) { + BlockScheduler blockScheduler; + uint32_t n2 = params.problemShape.k(); + uint32_t k2 = params.problemShape.n() / 2; + uint32_t startCoreIdx = 0; + int64_t gmGroupOffsetC = 0; + uint32_t aivCoreNum = coreNum; + uint32_t aicCoreNum = coreNum / 2; + uint32_t aivCoreIdx = coreIdx; + uint32_t aicCoreIdx = get_block_idx(); + uint32_t aivSubCoreIdx = get_subblockid(); + uint32_t preSrcExpertSum = 0; + + typename BlockEpilogue2::Params epilogueParams{ + static_cast(params.EP), + static_cast(params.expertPerRank), + static_cast(params.rank), + reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert), + params.layoutD2, + static_cast(n2), + static_cast(L1TileShape::N), + shmem, + static_cast(peermemInfo.offsetD) + }; + + BlockEpilogue2 blockEpilogue(resource, epilogueParams); + + int32_t syncLoopIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx; + + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) { + + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + int32_t m0 = 32; + int32_t m_rows = (actualBlockShape.m() + m0 - 1) / m0; + int32_t aiv_m_rows = m_rows / 2; + if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) { + aiv_m_rows += 1; + } + uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset + if(aivSubCoreIdx == 1) { + m_offset += (m_rows / 2) * m0; + } + if (loopIdx == startLoopIdx) { + for (;syncLoopIdx <= groupIdx; syncLoopIdx++) { + int32_t flag_id = 3 + syncLoopIdx / 8; + AscendC::CrossCoreWaitFlag<0x2>(flag_id); + } + } + + for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) { + GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1}; + uint32_t actualm = m0; + if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){ + actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0; + } + GemmCoord realTileShape{actualm, actualBlockShape.n(), 1}; + if constexpr (std::is_same_v) { + blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank); + } else { + blockEpilogue(gmC2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank, mPreSumBeforeRank); + } + m_offset += m0; + } + } + + for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) { + int32_t expertRankM = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + mPreSumBeforeRank[dstEpIdx] += expertRankM; + } + preSrcExpertSum += currentExpertM; + startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum; + } + + blockEpilogue.Finalize(); + AscendC::SyncAll(); + ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank); + shmem.CrossRankSync(); + + MoeTokenUnpermuteTilingData tilingData; + MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, aivCoreNum); + KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; + + kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); + kernelMoeTokenUnpermuteOp.Process(); + } + + +private: + struct WorkspaceInfo { + GM_ADDR ptrA; + GM_ADDR ptrPerTokenScale; + GM_ADDR ptrcumsumMM; + GM_ADDR ptrC; + GM_ADDR ptrC2; + GM_ADDR ptrPermutedToken; + GM_ADDR ptrPerTokenScale2; + GM_ADDR expandedRowIdx; + GM_ADDR ptrTokenPerExpert; + GM_ADDR ptrSumBeforeRank; + + CATLASS_DEVICE + WorkspaceInfo(){} + + CATLASS_DEVICE + WorkspaceInfo(const Params & params) { + uint32_t k2 = params.problemShape.n() / 2; + uint32_t n2 = params.problemShape.k(); + int64_t workspaceOffset = 0; + expandedRowIdx = params.ptrWorkspace; + + workspaceOffset += AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t); + ptrcumsumMM = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + ptrPerTokenScale = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale); + ptrPerTokenScale2 = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale); + ptrTokenPerExpert = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + ptrC = params.ptrWorkspace + workspaceOffset; //7 + + workspaceOffset += params.maxOutputSize * params.problemShape.n() * sizeof(ElementC); + ptrC2 = params.ptrWorkspace + workspaceOffset; //8 + + workspaceOffset += params.maxOutputSize * n2 * sizeof(ElementC); + ptrA = params.ptrWorkspace + workspaceOffset; //9 + + workspaceOffset += params.maxOutputSize * params.problemShape.k() * sizeof(ElementA); + ptrPermutedToken = params.ptrWorkspace + workspaceOffset; //10 + + workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA); + ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset; + } + }; + + struct PeermemInfo { + int64_t offsetA; + int64_t offsetPeerPerTokenScale; + int64_t offsetPeerTokenPerExpert; + int64_t offsetD; + + CATLASS_DEVICE + PeermemInfo(){} + + CATLASS_DEVICE + PeermemInfo(const Params & params, const HcclShmem & shmem) { + offsetA = 0; + offsetPeerPerTokenScale = offsetA + AlignUp(shmem.SegmentSize() / 3, 512); + offsetD = offsetPeerPerTokenScale + MB_SIZE; + offsetPeerTokenPerExpert = shmem.SegmentSize() - 2 * MB_SIZE; + } + }; + + Arch::Resource resource; + + uint32_t coreIdx; + uint32_t coreNum; + + WorkspaceInfo workspaceInfo; + PeermemInfo peermemInfo; + + int64_t m_prevSumBeforeRank; + + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmC; + AscendC::GlobalTensor gmS; + + AscendC::GlobalTensor gmPermutedToken; + AscendC::GlobalTensor gmS2; + AscendC::GlobalTensor gmC2; + + AscendC::GlobalTensor gmPerTokenScale1; + AscendC::GlobalTensor gmPerTokenScale2; + + AscendC::GlobalTensor tokenPerExpert; + AscendC::GlobalTensor cumsumMM; + AscendC::GlobalTensor preSumBeforeRank; + uint32_t mPreSumBeforeRank[32] = {0}; + Layout3D tokenPerExpertLayout; + HcclShmem shmem; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // DISPATH_FFN_COMBINE_KERNEL_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_tiling.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_tiling.h new file mode 100644 index 00000000..64cc6ae3 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_tiling.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_combine_tiling.h + * \brief + */ + +#include "moe_init_routing_v2/moe_init_routing_v2_tiling.h" + +#ifndef ASCENDC_DISPATCH_FFN_COMBINE_BF16_TILING_H +#define ASCENDC_DISPATCH_FFN_COMBINE_BF16_TILING_H +struct DispatchFFNCombineBF16Info { + uint32_t M; + uint32_t K; + uint32_t N; + uint32_t expertPerRank; + uint32_t maxOutputSize; + uint32_t isTransposeB; + uint32_t isWeightNz; + uint32_t aivNum; + uint32_t totalUbSize; + uint32_t topK; + uint32_t worldSize; + uint32_t listLen; +}; + +struct CoCTiling { + int32_t m0 = -1; + int32_t k0 = -1; + int32_t n0 = -1; + int32_t swizzleDirect = -1; + int32_t swizzleOffset = -1; + int32_t ubMoveNum = -1; + int32_t pValue = -1; + int32_t commNpuSplit = -1; + int32_t commDataSplit = -1; + int32_t lenPerLoop = -1; + uint64_t initRoutingQuantTilingKey; + optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData; +}; + +struct DispatchFFNCombineBF16TilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + DispatchFFNCombineBF16Info dispatchFFNCombineBF16Info; + CoCTiling cocTiling; +}; +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2.cpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2.cpp new file mode 100644 index 00000000..22ad44d7 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2.cpp @@ -0,0 +1,125 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_init_routing_v2.cpp + * \brief + */ + + +#ifdef __DAV_C310__ +#include "arch35/moe_v2_mrgsort_out.h" +#include "arch35/moe_v2_mrgsort.h" +#include "arch35/moe_v2_sort_multi_core.h" +#include "arch35/moe_v2_sort_one_core.h" +#include "arch35/moe_v2_expert_token_out_regbase.h" +#include "arch35/moe_v2_expert_token_out_simt.h" +#include "arch35/moe_v2_src_to_dst_op_simt.h" +#include "arch35/moe_v2_src_to_dst_with_capacity_simt.h" +#include "arch35/moe_v2_gather_out_for_simt.h" +#else +#include "moe_v2_mrgsort_out.h" +#include "moe_v2_mrgsort.h" +#include "moe_v2_sort_multi_core.h" +#include "moe_v2_sort_one_core.h" +#include "moe_v2_expert_token_out.h" +#include "moe_v2_src_to_dst_op.h" +#include "moe_v2_src_to_dst_with_capacity.h" +#include "moe_v2_gather_out.h" +#include "moe_v2_init_routing_fullload.h" +#endif + +using namespace AscendC; +using namespace MoeInitRoutingV2; +using namespace optiling; + +template +__aicore__ inline void moe_init_routing_v2(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, + const MoeInitRoutingV2TilingData* tilingData, uint64_t tilingKey) +{ + + if (g_coreType == AIC) { + return; + } + + // GET_TILING_DATA(tilingData, tiling); + if (workspace == nullptr) { + return; + } + + GM_ADDR userWS = workspace; + if (userWS == nullptr) { + return; + } + // auto t = tilingData; + if (tilingKey == 20000) { + TPipe sortPipe; + MoeV2FullLoad op; + op.Init(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, userWS, tilingData, &sortPipe); + op.Process(); + sortPipe.Destroy(); + // trap(); + return; + } + + if (tilingKey == 10001 || tilingKey == 10011) { + TPipe sortPipe; + MoeV2SortOneCore op; + op.Init(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, userWS, tilingData, + &sortPipe); + op.Process(); + sortPipe.Destroy(); + } else if (tilingKey == 10002 || tilingKey == 10012) { + TPipe sortPipe; + MoeV2SortMultiCore op; + op.Init(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, userWS, tilingData, + &sortPipe); + op.Process(); + sortPipe.Destroy(); + } + + if (tilingKey == 10001 || tilingKey == 10002) { + if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) { + TPipe expertTokenOutPipe; + MoeV2ExpertTokenOut expertTokenOutOp; + expertTokenOutOp.Init(expertTokensCountOrCumsum, expertTokensBeforeCapacity, + expandedRowIdx, userWS, tilingData, &expertTokenOutPipe); + expertTokenOutOp.Process(); + expertTokenOutPipe.Destroy(); + } + TPipe srcToDstPipe; + MoeV2SrcToDstOp srcToDstOp; + srcToDstOp.Init(expandedRowIdx, userWS, tilingData, &srcToDstPipe); + srcToDstOp.Process(); + srcToDstPipe.Destroy(); + } else if (tilingKey == 10011 || tilingKey == 10012) { + TPipe expertTokenOutPipe; + MoeV2ExpertTokenOut expertTokenOutOp; + expertTokenOutOp.Init(expertTokensCountOrCumsum, expertTokensBeforeCapacity, + expandedRowIdx, userWS, tilingData, &expertTokenOutPipe); + expertTokenOutOp.Process(); + expertTokenOutPipe.Destroy(); + + TPipe srcToDstPipe; + MoeV2SrcToDstWithCapacity srcToDstWithCapacityOp; + srcToDstWithCapacityOp.Init(expandedRowIdx, expandedX, userWS, tilingData, &srcToDstPipe); + srcToDstWithCapacityOp.Process(); + srcToDstPipe.Destroy(); + } + + TPipe gatherPipe; + MoeV2GatherOut gatherOp; + gatherOp.Init(x, expandedRowIdx, expandedX, userWS, tilingData, &gatherPipe); + gatherOp.Process(); + gatherPipe.Destroy(); + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2_tiling.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2_tiling.h new file mode 100644 index 00000000..6c3fd289 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_init_routing_v2_tiling.h @@ -0,0 +1,557 @@ +#pragma once + +#include "tiling_base.h" + + +namespace optiling { +const static int64_t TILING_KEY_DROPLESS_SORT_ONE_CORE = 10001; +const static int64_t TILING_KEY_DROPLESS_SORT_MULTI_CORE = 10002; +const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE = 10011; +const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE = 10012; +const static int64_t TILING_KEY_HIGH_PERFORMANCE = 20000; +const static int64_t NUM_TWO = 2; +const static int64_t NUM_THREE = 3; +const static int64_t NUM_FOUR = 4; +const static int64_t MRG_LIST_NUM = 4; +const static int64_t SORT32_ALIGN_ELEMENT = 32; +const static int64_t ONE_BLOCK_BYTE = 32; +const static size_t DIM_ONE = 1; +const static size_t DIM_TWO = 2; +const static size_t DIM_THREE = 3; +const static int32_t SIZE_16 = 16; +const static int32_t LENGTH_1024 = 1024; +const static int64_t MAX_COLS_ONE_LOOP = 16376; +const static int64_t ASSIST_NUM = 256; +const static int64_t INDEX_INPUT_X = 0; +const static int64_t INDEX_INPUT_EXPERT_IDX = 1; +const static int64_t ATTR_ACTIVE_ROWS = 0; +const static int64_t ATTR_EXPERT_CAPACITY = 1; +const static int64_t ATTR_EXPERT_NUM = 2; +const static int64_t ATTR_DROP_PAD_MODE = 3; +const static int64_t ATTR_EXPERT_TOKENS_COUNT_OR_CUMSUM_FLAG = 4; +const static int64_t ATTR_EXPERT_TOKENS_BEFORE_CAPACITY_FLAG = 5; +const static int64_t OUTOUT_EXPANDED_X = 0; +const static int64_t OUTOUT_EXPANDED_ROW_IDX = 1; +const static int64_t OUTOUT_EXPERT_TOKENS_COUNT_OR_CUMSUM = 2; +const static int64_t OUTOUT_EXPERT_TOKENS_BEFORE_CAPACITY = 3; +const static int64_t KV_FACTOR = 2; +const static int64_t ONE_CORE_SORT_BUFFER = 6; +const static int64_t EXPERT_TOKENS_COUNT = 2; +const static int64_t ONE_CORE_SORT_BUFFER_310P = 24; + + +inline static int64_t CeilLog4(int64_t x) { + return static_cast(std::ceil(std::log(x) / std::log(NUM_FOUR))); +} + +inline static int64_t GetPerOrLastValue(int64_t x, int64_t y) { + if (y == 0) { + return 0; + } + return x <= y ? x : x % y; +} + +template +constexpr T CeilDiv(const T dividend, const T divisor) +{ + return (dividend + divisor - 1) / divisor; +} + + +struct MoeV2VBSComputeTilingData { + int64_t needCoreNum = 0; + int64_t perCoreElements = 0; + int64_t perCoreLoops = 0; + int64_t perCorePerLoopElements = 0; + int64_t perCoreLastLoopElements = 0; + int64_t lastCoreElements = 0; + int64_t lastCoreLoops = 0; + int64_t lastCorePerLoopElements = 0; + int64_t lastCoreLastLoopElements = 0; + int64_t oneLoopMaxElements = 0; +}; + +struct MoeV2VMSMiddleComputeTilingData { + int64_t needCoreNum = 0; +}; + +struct MoeV2SortOutComputeTilingData { + int64_t oneLoopMaxElements = 0; +}; + +struct MoeV2GatherOutComputeTilingData { + int64_t needCoreNum = 0; + int64_t activateRows = 0; + int64_t perCoreRows = 0; + int64_t perCorePerLoopRows = 0; + int64_t perCoreLastLoopRows = 0; + int64_t lastCoreRows = 0; + int64_t lastCorePerLoopRows = 0; + int64_t lastCoreLastLoopRows = 0; + int64_t perCoreLoops = 0; + int64_t lastCoreLoops = 0; + int64_t perLoopCols = 0; + int64_t lastLoopCols = 0; + int64_t colLoops = 0; +}; + +struct MoeInitRoutingV2TilingData { + int64_t coreNum; + int64_t n; + int64_t cols; + int64_t k; + int64_t expertCapacity; + int64_t expertNum; + int64_t dropPadMode; + int64_t expertTokensCountOrCumsumFlag; + int64_t expertTokensBeforeCapacityFlag; + MoeV2VBSComputeTilingData vbsComputeParamsOp; + MoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp; + MoeV2SortOutComputeTilingData sortOutComputeParamsOp; + MoeV2GatherOutComputeTilingData srcToDstComputeParamsOp; + MoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp; + MoeV2GatherOutComputeTilingData gatherOutComputeParamsOp; +}; + + +class MoeInitRoutingV2TilingBase : public TilingBaseClass { + +protected: + bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) override; + bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override; + + bool DoOpTiling() override; + uint64_t GetTilingKey() const override; + bool GetWorkspaceSize() override; + + +protected: + bool CheckTokenCount(int64_t num, const char* tag); + + void Tiling4GatherOutCompute(); + void Tiling4SrcToDstCompute(); + void Tiling4SrcToDstCapacityCompute(); + void Tiling4SortOutCompute(); + void Tiling4VMSMiddleCompute(); + void Tiling4VBSCompute(); + void ShowTilingData(); + void Tiling4VBSMultiCoreCompute(MoeV2VBSComputeTilingData* tilingData); + void Tiling4VBSOneCoreCompute(MoeV2VBSComputeTilingData* tilingData); + bool IsFullLoad(); + + + + + int64_t aivNum = 0; + int64_t sortLoopMaxElement = 0; + int64_t mrgSortListMaxElement = 2040; + int64_t totalLength = 0; + int64_t activateNum = 0; + int64_t expertCapacity = 0; + int64_t expertNum = 0; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 0; + bool expertTokensBeforeCapacityFlag = false; + int64_t inuptXDtypeSize_ = 0; + bool isFullLoad = false; + const char *opName = "DispatchFFNCombine Tiling Debug"; + +public: + MoeInitRoutingV2TilingData moeInitRoutingTilingData; +}; + + +bool MoeInitRoutingV2TilingBase::DoOpTiling() { + sortLoopMaxElement = + (aicoreParams_.ubSize) / (sizeof(int32_t) * NUM_TWO * NUM_FOUR) / SORT32_ALIGN_ELEMENT * SORT32_ALIGN_ELEMENT; + isFullLoad = IsFullLoad(); + Tiling4VBSCompute(); + Tiling4VMSMiddleCompute(); + Tiling4SortOutCompute(); + Tiling4SrcToDstCompute(); + Tiling4SrcToDstCapacityCompute(); + Tiling4GatherOutCompute(); + ShowTilingData(); + return true; +}; + +uint64_t MoeInitRoutingV2TilingBase::GetTilingKey() const { + if (isFullLoad) { + return TILING_KEY_HIGH_PERFORMANCE; + } + if (dropPadMode == 0) { + if (totalLength <= sortLoopMaxElement) { + return TILING_KEY_DROPLESS_SORT_ONE_CORE; + } else { + return TILING_KEY_DROPLESS_SORT_MULTI_CORE; + } + } else { + if (totalLength <= sortLoopMaxElement) { + return TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE; + } else { + return TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE; + } + } + return tilingKey_; +} + + +void MoeInitRoutingV2TilingBase::ShowTilingData() +{ + OP_LOGD(opName, + "moeInitRoutingTilingData is coreNum:%ld, n:%ld, cols:%ld, k:%ld, expertCapacity:%ld, expertNum:%ld, " + "dropPadMode:%ld, expertTokensCountOrCumsumFlag:%ld, expertTokensBeforeCapacityFlag:%ld", + moeInitRoutingTilingData.coreNum, moeInitRoutingTilingData.n, + moeInitRoutingTilingData.cols, moeInitRoutingTilingData.k, + moeInitRoutingTilingData.expertCapacity, moeInitRoutingTilingData.expertNum, + moeInitRoutingTilingData.dropPadMode, moeInitRoutingTilingData.expertTokensCountOrCumsumFlag, + moeInitRoutingTilingData.expertTokensBeforeCapacityFlag); + OP_LOGD(opName, + "MoeV2VBSComputeTilingData is needCoreNum:%ld, perCoreElements:%ld, perCoreLoops:%ld, " + "perCorePerLoopElements:%ld, " + "perCoreLastLoopElements:%ld, lastCoreElements:%ld, lastCoreLoops:%ld, lastCorePerLoopElements:%ld, " + "lastCoreLastLoopElements:%ld, oneLoopMaxElements:%ld", + moeInitRoutingTilingData.vbsComputeParamsOp.needCoreNum, + moeInitRoutingTilingData.vbsComputeParamsOp.perCoreElements, + moeInitRoutingTilingData.vbsComputeParamsOp.perCoreLoops, + moeInitRoutingTilingData.vbsComputeParamsOp.perCorePerLoopElements, + moeInitRoutingTilingData.vbsComputeParamsOp.perCoreLastLoopElements, + moeInitRoutingTilingData.vbsComputeParamsOp.lastCoreElements, + moeInitRoutingTilingData.vbsComputeParamsOp.lastCoreLoops, + moeInitRoutingTilingData.vbsComputeParamsOp.lastCorePerLoopElements, + moeInitRoutingTilingData.vbsComputeParamsOp.lastCoreLastLoopElements, + moeInitRoutingTilingData.vbsComputeParamsOp.oneLoopMaxElements); + OP_LOGD(opName, "VMSMiddleComputeTilingData is needCoreNum:%ld", + moeInitRoutingTilingData.vmsMiddleComputeParamsOp.needCoreNum); + OP_LOGD(opName, "SortOutComputeTilingData is oneLoopMaxElements:%ld", + moeInitRoutingTilingData.sortOutComputeParamsOp.oneLoopMaxElements); + OP_LOGD( + opName, + "SrcToDstComputeTilingData is needCoreNum:%ld, activateRows:%ld, perCoreRows:%ld, perCorePerLoopRows:%ld, " + "perCoreLastLoopRows:%ld, lastCoreRows:%ld, lastCorePerLoopRows:%ld, lastCoreLastLoopRows:%ld,", + moeInitRoutingTilingData.srcToDstComputeParamsOp.needCoreNum, + moeInitRoutingTilingData.srcToDstComputeParamsOp.activateRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.perCoreRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.perCorePerLoopRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.perCoreLastLoopRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.lastCoreRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.lastCorePerLoopRows, + moeInitRoutingTilingData.srcToDstComputeParamsOp.lastCoreLastLoopRows); + OP_LOGD(opName, + "SrcToDstComputeCapacityTilingData is needCoreNum:%ld, perCoreRows:%ld, perCorePerLoopRows:%ld, " + "perCoreLastLoopRows:%ld, lastCoreRows:%ld, lastCorePerLoopRows:%ld, lastCoreLastLoopRows:%ld,", + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.needCoreNum, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.perCoreRows, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.perCorePerLoopRows, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.perCoreLastLoopRows, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.lastCoreRows, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.lastCorePerLoopRows, + moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.lastCoreLastLoopRows); + OP_LOGD( + opName, + "GatherOutComputeTilingData is needCoreNum:%ld, activateRows:%ld, perCoreRows:%ld, perCorePerLoopRows:%ld, " + "perCoreLastLoopRows:%ld, lastCoreRows:%ld, lastCorePerLoopRows:%ld, lastCoreLastLoopRows:%ld,", + moeInitRoutingTilingData.gatherOutComputeParamsOp.needCoreNum, + moeInitRoutingTilingData.gatherOutComputeParamsOp.activateRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.perCoreRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.perCorePerLoopRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.perCoreLastLoopRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.lastCoreRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.lastCorePerLoopRows, + moeInitRoutingTilingData.gatherOutComputeParamsOp.lastCoreLastLoopRows); +} + + + +bool MoeInitRoutingV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activateNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) { + + this->activateNum = activateNum; + this->expertCapacity = expertCapacity; + this->expertNum = expertNum; + this->dropPadMode = dropPadMode; + this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag; + if (dropPadMode == 1) { + expertTokensCountOrCumsumFlag = 0; + } else { + expertTokensBeforeCapacityFlag = false; + } + moeInitRoutingTilingData.cols = cols; + moeInitRoutingTilingData.n = m; + moeInitRoutingTilingData.k = topK; + moeInitRoutingTilingData.expertCapacity = expertCapacity; + moeInitRoutingTilingData.expertNum = expertNum; + moeInitRoutingTilingData.dropPadMode = dropPadMode; + moeInitRoutingTilingData.expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag; + moeInitRoutingTilingData.expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag; + totalLength = moeInitRoutingTilingData.n * moeInitRoutingTilingData.k; + inuptXDtypeSize_ = inuptXDtypeSize; + return true; +} + +bool MoeInitRoutingV2TilingBase::GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) { + aivNum = aivCoreNum; + aicoreParams_.blockDim = aivCoreNum; + aicoreParams_.ubSize = ubSizePlatForm; + moeInitRoutingTilingData.coreNum = aivCoreNum; + return true; +} + +bool MoeInitRoutingV2TilingBase::GetWorkspaceSize() { + size_t sortWorkspaceSize = totalLength * sizeof(float) * NUM_TWO * NUM_THREE; + size_t scatterWorkspaceSize = totalLength * sizeof(int32_t) * NUM_TWO; + size_t expertTokenFlagSize = aivNum * 2 * sizeof(int32_t); + workspaceSize_ = sortWorkspaceSize + scatterWorkspaceSize + expertTokenFlagSize + SIZE_16 * LENGTH_1024 * LENGTH_1024; + return true; +} + + +void MoeInitRoutingV2TilingBase::Tiling4VBSOneCoreCompute(MoeV2VBSComputeTilingData* tilingData) { + tilingData->needCoreNum = 1; + tilingData->perCoreElements = totalLength; + tilingData->perCoreLoops = 1; + tilingData->perCorePerLoopElements = tilingData->perCoreElements; + tilingData->perCoreLastLoopElements = tilingData->perCoreElements; + tilingData->lastCoreElements = tilingData->perCoreElements; + tilingData->lastCoreLoops = 1; + tilingData->lastCorePerLoopElements = tilingData->perCoreElements; + tilingData->lastCoreLastLoopElements = tilingData->perCoreElements; +} + +void MoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(MoeV2VBSComputeTilingData* tilingData) { + //Tiling4VBSMultiCoreCompute + int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement); + needCoreNum = static_cast(std::pow(4, CeilLog4(needCoreNum))); + needCoreNum = std::min(needCoreNum, aivNum); + if (needCoreNum > 0) { + int64_t perCoreElements = totalLength / needCoreNum; + int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT; + int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements; + int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT; + if (lastCoreElement > alineCeilPerCoreElements) { + perCoreElements = alineCeilPerCoreElements; + needCoreNum = CeilDiv(totalLength, perCoreElements); + } else { + perCoreElements = alineFloorPerCoreElements; + } + tilingData->needCoreNum = needCoreNum; + do { + tilingData->perCoreElements = perCoreElements; + tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement); // 每个核处理的loop数 + tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement); + tilingData->perCoreLastLoopElements = tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements; + tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements; + tilingData->lastCoreLoops = tilingData->perCoreLoops; + int64_t tmp = CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops); + int64_t lastCorePerLoopElements = + CeilDiv(CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops), SORT32_ALIGN_ELEMENT) * + SORT32_ALIGN_ELEMENT; + tilingData->lastCorePerLoopElements = lastCorePerLoopElements; + tilingData->lastCoreLastLoopElements = tilingData-> lastCoreElements - (tilingData->lastCoreLoops - 1) * tilingData->lastCorePerLoopElements; + perCoreElements -= SORT32_ALIGN_ELEMENT; + } while (tilingData->lastCoreLastLoopElements <= 0 && perCoreElements > 0); + } +} + +void MoeInitRoutingV2TilingBase::Tiling4VBSCompute() { + auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp; + tilingData->oneLoopMaxElements = sortLoopMaxElement; + if (totalLength <= sortLoopMaxElement) { + Tiling4VBSOneCoreCompute(tilingData); + return; + } + Tiling4VBSMultiCoreCompute(tilingData); +} + +void MoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() { + auto vbsComputeTilingData = &moeInitRoutingTilingData.vbsComputeParamsOp; + auto tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp; + if (vbsComputeTilingData->needCoreNum <= MRG_LIST_NUM) { + tilingData->needCoreNum = 0; + } else { + int64_t needCoreNum = CeilDiv(vbsComputeTilingData->needCoreNum, MRG_LIST_NUM); + tilingData->needCoreNum = needCoreNum; + } +} + +void MoeInitRoutingV2TilingBase::Tiling4SortOutCompute() { + auto tilingData = &moeInitRoutingTilingData.sortOutComputeParamsOp; + tilingData->oneLoopMaxElements = mrgSortListMaxElement; +} + + +void MoeInitRoutingV2TilingBase::Tiling4SrcToDstCompute() { + auto tilingData = &moeInitRoutingTilingData.srcToDstComputeParamsOp; + + int64_t perLoopMaxRows = (aicoreParams_.ubSize - ASSIST_NUM * sizeof(float) - aivNum * SORT32_ALIGN_ELEMENT) / + (SORT32_ALIGN_ELEMENT * NUM_TWO) / NUM_TWO; + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + int64_t needCoreNum = CeilDiv(totalLength, perCoreRows); + tilingData->needCoreNum = needCoreNum; + int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->perCoreRows = perCoreRows; + if (perLoopMaxRows >= tilingData->perCoreRows) { + tilingData->perCorePerLoopRows = tilingData->perCoreRows; + tilingData->perCoreLastLoopRows = tilingData->perCoreRows; + } else { + tilingData->perCorePerLoopRows = perLoopMaxRows; + tilingData->perCoreLastLoopRows = tilingData->perCoreRows - (CeilDiv(tilingData->perCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows; + } + tilingData->lastCoreRows = lastCoreNum; + if (perLoopMaxRows >= tilingData->lastCoreRows) { + tilingData->lastCorePerLoopRows = tilingData->lastCoreRows; + tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows; + } else { + tilingData->lastCorePerLoopRows = perLoopMaxRows; + tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows - (CeilDiv(tilingData->lastCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows; + } +} + +void MoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute() { + auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp; + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + int64_t needCoreNum = CeilDiv(totalLength, perCoreRows); + tilingData->needCoreNum = needCoreNum; + int64_t cols = moeInitRoutingTilingData.cols; + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + + + int64_t rowSize = + (perCoreRows * sizeof(int32_t) * 2 + ONE_BLOCK_BYTE + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t colSize = (cols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + + if (rowSize + colSize < static_cast(aicoreParams_.ubSize)) { + tilingData->perCorePerLoopRows = perCoreRows; + tilingData->perCoreLastLoopRows = perCoreRows; + tilingData->lastCorePerLoopRows = lastCoreRows; + tilingData->lastCoreLastLoopRows = lastCoreRows; + tilingData->perCoreLoops = 1; + tilingData->lastCoreLoops = 1; + tilingData->perLoopCols = cols; + tilingData->lastLoopCols = cols; + tilingData->colLoops = 1; + + } else { + int64_t baseMaxCols = MAX_COLS_ONE_LOOP; + int64_t baseMaxColsSize = (baseMaxCols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - baseMaxColsSize - ONE_BLOCK_BYTE) / + static_cast(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + if (cols < MAX_COLS_ONE_LOOP) { + basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - colSize - ONE_BLOCK_BYTE) / + static_cast(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = + (static_cast(aicoreParams_.ubSize) - rowSize) / inuptXDtypeSize_ / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } + tilingData->perLoopCols = (std::min(baseMaxCols, cols)); + tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols)); + tilingData->colLoops = ((cols + baseMaxCols - 1) / baseMaxCols); + tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLoops = ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLoops = ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + } +} + +void MoeInitRoutingV2TilingBase::Tiling4GatherOutCompute() +{ + auto tilingData = &moeInitRoutingTilingData.gatherOutComputeParamsOp; + tilingData->activateRows = totalLength; + if (dropPadMode == 0) { + tilingData->activateRows = activateNum; + } + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + if (perCoreRows <= 0 || moeInitRoutingTilingData.cols <= 0) { + tilingData->needCoreNum = 0; + return; + } + tilingData->needCoreNum = CeilDiv(totalLength, perCoreRows); + int64_t cols = moeInitRoutingTilingData.cols; + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + + int64_t rowSize = (perCoreRows * sizeof(int32_t) + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t colSize = (cols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + + if (rowSize + colSize < static_cast(aicoreParams_.ubSize) / NUM_TWO) { + tilingData->perCorePerLoopRows = perCoreRows; + tilingData->perCoreLastLoopRows = perCoreRows; + tilingData->lastCorePerLoopRows = lastCoreRows; + tilingData->lastCoreLastLoopRows = lastCoreRows; + tilingData->perCoreLoops = 1; + tilingData->lastCoreLoops = 1; + tilingData->perLoopCols = cols; + tilingData->lastLoopCols = cols; + tilingData->colLoops = 1; + } else { + int64_t baseMaxCols = MAX_COLS_ONE_LOOP; + int64_t baseMaxColsSize = + (baseMaxCols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) / NUM_TWO - baseMaxColsSize) / + static_cast(sizeof(int32_t)) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + if (cols < MAX_COLS_ONE_LOOP) { + basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) / NUM_TWO - colSize) / + static_cast(sizeof(int32_t)) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = (static_cast(aicoreParams_.ubSize) / NUM_TWO - rowSize) / inuptXDtypeSize_ / + ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } + tilingData->perLoopCols = (std::min(baseMaxCols, cols)); + tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols)); + tilingData->colLoops = ((cols + baseMaxCols - 1) / baseMaxCols); + + tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLoops = ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + + tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLoops = ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + } +} + + +bool MoeInitRoutingV2TilingBase::IsFullLoad() +{ + if (totalLength > sortLoopMaxElement || moeInitRoutingTilingData.cols > MAX_COLS_ONE_LOOP || + this->dropPadMode == 1) { + return false; + } + int64_t sortBufferNum = ONE_CORE_SORT_BUFFER; + + int64_t sortSpace = + CeilDiv(this->totalLength, SORT32_ALIGN_ELEMENT) * SORT32_ALIGN_ELEMENT * sizeof(int32_t) * sortBufferNum; + int64_t otherSpace = + CeilDiv(this->totalLength, SORT32_ALIGN_ELEMENT) * SORT32_ALIGN_ELEMENT * sizeof(int32_t) * NUM_THREE; + int64_t expertSpace = CeilDiv(this->expertNum * int64_t(sizeof(int32_t)), ONE_BLOCK_BYTE) * ONE_BLOCK_BYTE; + int64_t perCoreXRows = moeInitRoutingTilingData.n / aivNum; + int64_t remainder = moeInitRoutingTilingData.n % aivNum; + // NUM_TWO is Max xRows need add 2 becauseof the left and right row may be another row. + perCoreXRows = remainder <= 1 ? perCoreXRows + 1 : perCoreXRows + NUM_TWO; + int64_t gatherSpace = + CeilDiv(moeInitRoutingTilingData.cols * inuptXDtypeSize_, ONE_BLOCK_BYTE) * ONE_BLOCK_BYTE * perCoreXRows; + int64_t remainUbAfterSort = aicoreParams_.ubSize - sortSpace - otherSpace - expertSpace - gatherSpace; + return remainUbAfterSort > 0; +} + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_common.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_common.h new file mode 100644 index 00000000..1e8806b5 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_common.h @@ -0,0 +1,201 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_common.h + * \brief + */ +#ifndef MOE_V2_COMMON_H +#define MOE_V2_COMMON_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; +constexpr int64_t SPLIT_N = 0; +constexpr int64_t SPLIT_K = 1; +constexpr float MIN_FP32 = -3.4e38f; +#if __CCE_AICORE__ == 200 +constexpr int64_t ONE_REPEAT_SORT_NUM = 16; +constexpr int64_t REGIONP_ROPOSAL_KV_RATIO = 4; // 8 / 2 +constexpr int64_t SYNC_LEN = 8 * 8 * 32; +#else +constexpr int64_t ONE_REPEAT_SORT_NUM = 32; +#endif +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t INT32_ONE_BLOCK_NUM = 8; +constexpr int64_t FP32_ONE_REPEAT_NUM = 64; + +constexpr int64_t ASSIST_NUM = 256; +constexpr int64_t ASSIST_INDEX_NUM = 32; + +constexpr int64_t MERGE_LIST_TWO = 2; +constexpr int64_t MERGE_LIST_THREE = 3; +constexpr int64_t MERGE_LIST_FOUR = 4; + +constexpr int64_t MERGE_LIST_IDX_TWO = 2; +constexpr int64_t MERGE_LIST_IDX_THREE = 3; + +constexpr int64_t MAX_EXPERT_NUM = 5120; +constexpr int64_t DROPLESS_MODE = 0; +constexpr int64_t DROP_PAD_MODE = 1; +constexpr int64_t EXERPT_TOKENS_COUNT = 2; +constexpr int64_t EXERPT_TOKENS_CUMSUM = 1; +constexpr int64_t EXERPT_TOKENS_NONE = 0; +constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1; + +const __gm__ int32_t assist[256] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, + 4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, + 16, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, + 20, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0, + 24, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0}; + +__aicore__ inline int64_t Ceil(int64_t a, int64_t b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) +{ + if (bytes == 0) { + return 0; + } + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes; +} + +__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) +{ + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES; +} + +template +__aicore__ inline T Min(T a, T b) +{ + return a > b ? b : a; +} + +template +__aicore__ inline T Max(T a, T b) +{ + return a < b ? b : a; +} + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + SetFlag(eventId); + WaitFlag(eventId); +} + +template +__aicore__ inline void DataCopyPadCustom(LocalTensor inLocal, GlobalTensor srcGm, DataCopyExtParams tokenCopyParams, DataCopyPadExtParams padParams) +{ +#if __CCE_AICORE__ == 220 + DataCopyPad(inLocal, srcGm, tokenCopyParams, padParams); +#else + int64_t elem = tokenCopyParams.blockLen / sizeof(T); + int64_t numPerBlock = BLOCK_BYTES / sizeof(T); + int64_t alignElem = AlignUp(elem, numPerBlock); + + if (likely(alignElem == elem)) { + + DataCopyParams copyParams = {tokenCopyParams.blockCount, static_cast(alignElem / numPerBlock) , 0, 0}; + DataCopy(inLocal, srcGm, copyParams); + } else { + DataCopyParams copyParams = {1, static_cast(alignElem / numPerBlock) , 0, 0}; + for (uint32_t i = 0; i < tokenCopyParams.blockCount; i++) { + DataCopy(inLocal[i * alignElem], srcGm[i * elem], copyParams); + } + } +#endif +} + +template +__aicore__ inline void DataCopyCustom(GlobalTensor dstGm, LocalTensor inLocal, int64_t blockCount, int64_t blockLen) +{ + int64_t elem = blockLen / sizeof(T); + int64_t numPerBlock = sizeof(T) == 0 ? 1 : BLOCK_BYTES / sizeof(T); + int64_t alignElem = AlignUp(elem, numPerBlock); + + if (likely(alignElem == elem)) { + DataCopyParams copyParams = {static_cast(blockCount), static_cast(alignElem / numPerBlock) , 0, 0}; + DataCopy(dstGm, inLocal, copyParams); + } else { + if (blockCount == 1) { + if constexpr (needBack) { + int64_t elemAlignDown = numPerBlock == 0 ? 0 : elem / numPerBlock * numPerBlock; + if (elemAlignDown != 0) { + DataCopyParams copyParams = {static_cast(blockCount), static_cast(elemAlignDown / numPerBlock) , 0, 0}; + DataCopy(dstGm, inLocal, copyParams); + SetWaitFlag(HardEvent::MTE2_S); + SetWaitFlag(HardEvent::V_S); + + for (uint32_t i = 0; i < numPerBlock; i++) { + inLocal.SetValue(alignElem-1-i, inLocal.GetValue(elem - 1 - i)); + } + SetWaitFlag(HardEvent::S_MTE3); + + DataCopyParams copyParamslast = {1, 1, 0, 0}; + + DataCopy(dstGm[elem-numPerBlock], inLocal[elemAlignDown], copyParamslast); + } else { + T tmp[BLOCK_BYTES]; + SetWaitFlag(HardEvent::MTE2_S); + SetWaitFlag(HardEvent::V_S); + for (uint32_t i = 0; i < elem; i++) { + tmp[i] = inLocal.GetValue(elem - 1 - i); + } + DataCopyParams copyParamslast = {1, 1, 0, 0}; + SetWaitFlag(HardEvent::S_MTE2); + SetWaitFlag(HardEvent::MTE3_MTE2); + DataCopy(inLocal, dstGm[elem-numPerBlock], copyParamslast); + SetWaitFlag(HardEvent::MTE2_S); + for (uint32_t i = 0; i < elem; i++) { + inLocal.SetValue(numPerBlock-1-i, tmp[i]); + } + SetWaitFlag(HardEvent::S_MTE3); + DataCopy(dstGm[elem-numPerBlock], inLocal, copyParamslast); + } + + } else if constexpr (isAtomic) { + SetWaitFlag(HardEvent::MTE2_S); + SetWaitFlag(HardEvent::V_S); + for (uint32_t i = 0; i < alignElem - elem; i++) { + inLocal.SetValue(alignElem-1-i, T(0)); + } + SetWaitFlag(HardEvent::S_MTE3); + + DataCopyParams copyParams = {static_cast(blockCount), static_cast(alignElem / numPerBlock) , 0, 0}; + DataCopy(dstGm, inLocal, copyParams); + } else { + DataCopyParams copyParams = {static_cast(blockCount), static_cast(alignElem / numPerBlock) , 0, 0}; + DataCopy(dstGm, inLocal, copyParams); + } + } else { + DataCopyParams copyParams = {1, static_cast(alignElem / numPerBlock) , 0, 0}; + for (uint32_t i = 0; i < blockCount; i++) { + DataCopy(dstGm[i * elem], inLocal[i * alignElem], copyParams); + PipeBarrier(); + } + } + } +} + +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_COMMON_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_expert_token_out.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_expert_token_out.h new file mode 100644 index 00000000..6a022e97 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_expert_token_out.h @@ -0,0 +1,380 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_expert_token_out.h + * \brief + */ +#ifndef MOE_V2_EXPERT_TOKEN_OUT_H +#define MOE_V2_EXPERT_TOKEN_OUT_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +constexpr int64_t EXPERT_ID_VALUE_NUM = 2; + +class MoeV2ExpertTokenOut { +public: + __aicore__ inline MoeV2ExpertTokenOut(){}; + template + __aicore__ inline void Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR expandedRowIdx, GM_ADDR workspace, const TilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void Compute(int64_t progress); + __aicore__ inline void SyncAll(); + __aicore__ inline void InitLocal(); + __aicore__ inline void GetExpertTokenCount(int32_t curExpertId); + __aicore__ inline void CopyOutTokenGm(); + __aicore__ inline void CopyOutExpertTokensCumsum(bool isTail); + __aicore__ inline void CopyOutExpertTokensCount(bool isTail); + +private: + TPipe *pipe; + TQue copyInQueue; + TQue expertTokenIdxCopyInQueue; + TQue expertTokenIdxCopyOutQueue; + + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expertIdxValueGm; + GlobalTensor expandedRowIdxGm; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + GlobalTensor syncTmpSpaceGm_; + TBuf workBuffer; +#endif + LocalTensor expertTokenIdxOutLocal; + + const MoeV2GatherOutComputeTilingData *srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t expertNum; + int64_t expertNumUbAlign; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + + int64_t tokenCount = 0; + int64_t expertIdx = 0; + int32_t lastExpertId = -1; + int32_t firstExpertId = -1; + + int32_t expertTokenValue = 0; +}; + +__aicore__ inline void MoeV2ExpertTokenOut::InitLocal() +{ + LocalTensor tokenIdxLocal = expertTokenIdxCopyOutQueue.AllocTensor(); + Duplicate(tokenIdxLocal, 0, this->expertNumUbAlign); + expertTokenIdxCopyOutQueue.EnQue(tokenIdxLocal); + + // expandedRowIdx initialized to -1, which is used in the src_to_dst_with_capacity step. + // use this step SyncAll to synchronize every core data + if (this->dropPadMode == 0) { + return; + } + LocalTensor outLocal = copyInQueue.AllocTensor(); + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + Duplicate(outLocal, -1, perLoopRows); + SetWaitFlag(HardEvent::V_MTE3); + for (int64_t loop = 0; loop < loops; loop++) { + int64_t copyLength = perLoopRows; + if (loop == loops - 1) { + copyLength = lastLoopRows; + } + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, + 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom( + expandedRowIdxGm[this->blockIdx * this->srcToDstTilingData->perCoreRows + loop * perLoopRows], outLocal, + copyParams.blockCount, copyParams.blockLen); +#else + DataCopyPad(expandedRowIdxGm[this->blockIdx * this->srcToDstTilingData->perCoreRows + loop * perLoopRows], + outLocal, copyParams); +#endif + } + SetWaitFlag(HardEvent::MTE3_MTE2); + copyInQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyIn(int64_t progress) +{ + LocalTensor inLocal = copyInQueue.AllocTensor(); + DataCopy(inLocal, expandedExpertIdxGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t))); + copyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::GetExpertTokenCount(int32_t curExpertId) +{ + this->tokenCount++; + if (this->lastExpertId < curExpertId) { + this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount - 1); + this->tokenCount = 1; + this->expertIdx += (curExpertId - this->lastExpertId); + while (curExpertId - this->firstExpertId + 1 > this->expertNumUbAlign) { + SetWaitFlag(HardEvent::S_MTE3); + CopyOutExpertTokensCumsum(false); + CopyOutExpertTokensCount(false); + SetWaitFlag(HardEvent::MTE3_V); + Duplicate(this->expertTokenIdxOutLocal, 0, this->expertNumUbAlign); + SetWaitFlag(HardEvent::V_S); + this->firstExpertId += this->expertNumUbAlign; + this->expertIdx = curExpertId - this->firstExpertId; + } + this->lastExpertId = curExpertId; + } +} + +__aicore__ inline void MoeV2ExpertTokenOut::Compute(int64_t progress) +{ + LocalTensor inLocal = copyInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = inLocal.GetValue(0); + this->firstExpertId = this->lastExpertId; + } + for (int64_t i = 0; i < currentLoopRows; i++) { + int32_t expertId = inLocal.GetValue(i); + GetExpertTokenCount(expertId); + } + this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount); + copyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCumsum(bool isTail) +{ + if (this->dropPadMode != DROPLESS_MODE || expertTokensCountOrCumsumFlag != EXERPT_TOKENS_CUMSUM) { + return; + } +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) { + return; + } +#endif + int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign; + int64_t end = this->expertNum - this->firstExpertId; + for (int64_t i = 0; i < copyLength; i++) { + this->expertTokenValue += this->expertTokenIdxOutLocal.GetValue(i); + this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue); + } + // if the remianing UB is sufficient, use the UB space to copy + // otherwise, copy the calculated data first, and then copy the last tokenValue to remaining expert position + if (isTail && end <= this->expertNumUbAlign) { + int64_t startAlign = Min(Align(copyLength, sizeof(int32_t)), end); + for (int64_t i = copyLength; i < startAlign; i++) { + this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue); + } + if (startAlign < end) { + Duplicate(this->expertTokenIdxOutLocal[startAlign], this->expertTokenValue, end - startAlign); + } + copyLength = end; + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, + 0}; + SetAtomicAdd(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom( + expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, + copyParams.blockCount, copyParams.blockLen); +#else + DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); +#endif + SetAtomicNone(); + if (isTail && end > this->expertNumUbAlign) { + int64_t remainderLength = end - copyLength; + SetWaitFlag(HardEvent::MTE3_V); + Duplicate(this->expertTokenIdxOutLocal, this->expertTokenValue, this->expertNumUbAlign); + SetWaitFlag(HardEvent::V_MTE3); + int64_t loopTimes = remainderLength / this->expertNumUbAlign + 1; + for (int64_t i = 0; i < loopTimes; i++) { + copyLength = i == loopTimes - 1 ? remainderLength - this->expertNumUbAlign * i : this->expertNumUbAlign; + DataCopyExtParams params{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, + 0, 0}; + SetAtomicAdd(); + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom( + expertTokensCountOrCumsumGm[this->lastExpertId + 1 + this->expertNumUbAlign * i], + this->expertTokenIdxOutLocal, + params.blockCount, params.blockLen); + +#else + DataCopyPad(expertTokensCountOrCumsumGm[this->lastExpertId + 1 + this->expertNumUbAlign * i], + this->expertTokenIdxOutLocal, params); +#endif + SetAtomicNone(); + } + } +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCount(bool isTail) +{ + int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign; + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, + 0}; +#ifdef __CCE_KT_TEST__ + // CPU孪生调试不进行输出拷贝 + return; +#endif + SetAtomicAdd(); + if (this->dropPadMode == DROP_PAD_MODE && expertTokensBeforeCapacityFlag > EXERPT_TOKENS_NONE) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expertTokensBeforeCapacityGm[this->firstExpertId], + this->expertTokenIdxOutLocal, copyParams.blockCount, copyParams.blockLen); + #else + DataCopyPad(expertTokensBeforeCapacityGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); +#endif + } + if (this->dropPadMode == DROPLESS_MODE && expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, + copyParams.blockCount, copyParams.blockLen); +#else + DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); +#endif + } + SetAtomicNone(); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutTokenGm() +{ + if (this->dropPadMode == DROPLESS_MODE) { + SetWaitFlag(HardEvent::S_MTE3); + CopyOutExpertTokensCumsum(true); + CopyOutExpertTokensCount(true); + return; + } + this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign, this->lastExpertId); + this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign + 1, this->tokenCount); + DataCopyExtParams copyParams{static_cast(1), static_cast(EXPERT_ID_VALUE_NUM * sizeof(int32_t)), + 0, 0, 0}; + SetWaitFlag(HardEvent::S_MTE3); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expertIdxValueGm[this->blockIdx * BLOCK_BYTES / sizeof(int32_t)], + this->expertTokenIdxOutLocal[this->expertNumUbAlign], copyParams.blockCount, copyParams.blockLen); +#else + DataCopyPad(expertIdxValueGm[this->blockIdx * EXPERT_ID_VALUE_NUM], + this->expertTokenIdxOutLocal[this->expertNumUbAlign], copyParams); +#endif + CopyOutExpertTokensCount(true); +} + +__aicore__ inline void MoeV2ExpertTokenOut::SyncAll() +{ + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor syncLocal = workBuffer.Get(); + AscendC::SyncAll(syncTmpSpaceGm_, syncLocal, GetBlockNum()); +#else + AscendC::SyncAll(); +#endif +#endif +} + +template +__aicore__ inline void MoeV2ExpertTokenOut::Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR expandedRowIdx, GM_ADDR workspace, + const TilingData *tilingData, TPipe *tPipe) +{ + int64_t blockNum = GetBlockNum() * 2; + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + // this->blockIdx = GetBlockIdx(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp); + this->expertNum = tilingData->expertNum; + this->dropPadMode = tilingData->dropPadMode; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + } + + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, Align(this->totalLength, sizeof(int32_t))); + if (this->dropPadMode == DROPLESS_MODE && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensCountOrCumsum, this->expertNum); + } + if (this->dropPadMode == DROP_PAD_MODE && this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { + expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensBeforeCapacity, this->expertNum); + } + + expandedExpertIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace + + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + syncTmpSpaceGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(this->totalLength, sizeof(int32_t)) * EXERPT_TOKENS_COUNT + + this->coreNum * BLOCK_BYTES / sizeof(int32_t), + INT32_ONE_BLOCK_NUM * GetBlockNum() * 2 * BLOCK_BYTES); + pipe->InitBuffer(workBuffer, INT32_ONE_BLOCK_NUM * GetBlockNum() * 2 * BLOCK_BYTES); + LocalTensor syncLocal = workBuffer.Get(); + Duplicate(syncLocal, 0, SYNC_LEN); + SetWaitFlag(HardEvent::V_MTE3); + DataCopy(syncTmpSpaceGm_, syncLocal, SYNC_LEN); + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + this->coreNum * BLOCK_BYTES / sizeof(int32_t)); +#else + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + this->coreNum * 2); +#endif + this->expertNumUbAlign = Min(Align(this->expertNum, sizeof(int32_t)), MAX_EXPERT_NUM); + pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES); + pipe->InitBuffer(expertTokenIdxCopyInQueue, 1, this->expertNumUbAlign * sizeof(int32_t)); + pipe->InitBuffer(expertTokenIdxCopyOutQueue, 1, (this->expertNumUbAlign + EXPERT_ID_VALUE_NUM) * sizeof(int32_t)); +} + +__aicore__ inline void MoeV2ExpertTokenOut::Process() +{ + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + currentLoopRows = perLoopRows; + InitLocal(); + this->expertTokenIdxOutLocal = expertTokenIdxCopyOutQueue.DeQue(); + for (int64_t loop = 0; loop < loops - 1; loop++) { + CopyIn(loop); + Compute(loop); + } + currentLoopRows = lastLoopRows; + CopyIn(loops - 1); + Compute(loops - 1); + CopyOutTokenGm(); + expertTokenIdxCopyOutQueue.FreeTensor(this->expertTokenIdxOutLocal); + } + this->SyncAll(); +} + +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_EXPERT_TOKEN_OUT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_gather_out.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_gather_out.h new file mode 100644 index 00000000..496c0bbd --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_gather_out.h @@ -0,0 +1,198 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_gather_out.h + * \brief + */ +#ifndef MOE_V2_GATHER_OUT_H +#define MOE_V2_GATHER_OUT_H + +#include "moe_v2_common.h" +#include "kernel_operator.h" +using namespace optiling; + +namespace MoeInitRoutingV2 { +using namespace AscendC; + +constexpr int64_t BUFFER_NUM = 2; + +template +class MoeV2GatherOut { +public: + __aicore__ inline MoeV2GatherOut(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const MoeInitRoutingV2TilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyInIndices(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + +private: + TPipe *pipe; + TQueBind inputActivationsCopyInQueue; + TQue expandDstToSrcRowCopyInQueue; + + GlobalTensor inputXGm; + GlobalTensor expandedXGm; + GlobalTensor expandedRowIdxGm; + + const MoeV2GatherOutComputeTilingData *gatherOutTilingData; + + int64_t needCoreNum; + int64_t blockIdx; + int64_t cols; + int64_t n; + int64_t k; + int64_t activateRows; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t colsTileLength; + int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + int64_t dropPadMode; + + int64_t indicesOffset; + int64_t inputOffset; + int64_t outOffset; +}; + +template +__aicore__ inline void MoeV2GatherOut::CopyInIndices(int64_t progress) +{ + this->indicesOffset = progress * this->perLoopRows; + LocalTensor indicesLocal = expandDstToSrcRowCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); +#else + DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); +#endif + expandDstToSrcRowCopyInQueue.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherOut::CopyOut(int64_t progress) +{ + LocalTensor indicesLocal = expandDstToSrcRowCopyInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength = this->perLoopCols; + for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) { + int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + int64_t curLoopRow = 0; + if (colsLoop == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + int64_t currentLoopStartRow = initialRow / this->k; + int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = inputActivationsCopyInQueue.AllocTensor(); + // input row position + inputOffset = row * this->cols + colsLoop * this->perLoopCols; + DataCopyExtParams dataCopyParams{1, static_cast(this->colsTileLength * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams); +#else + DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams); +#endif + SetWaitFlag(HardEvent::MTE2_MTE3); + + DataCopyExtParams intriParams{1, static_cast(this->colsTileLength * sizeof(T)), 0, 0, 0}; + while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { + continue; + } + outOffset = outIndex * cols + colsLoop * this->perLoopCols; +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (outOffset > expandedXGm.GetSize()) { + continue; + } +#endif +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expandedXGm[outOffset], inLocal, intriParams.blockCount, intriParams.blockLen); +#else + DataCopyPad(expandedXGm[outOffset], inLocal, intriParams); +#endif + } + inputActivationsCopyInQueue.FreeTensor(inLocal); + } + } + expandDstToSrcRowCopyInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherOut::Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR workspace, const MoeInitRoutingV2TilingData *tilingData, + TPipe *tPipe) +{ + this->pipe = tPipe; + // this->blockIdx = GetBlockIdx(); + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp); + + this->needCoreNum = this->gatherOutTilingData->needCoreNum; + this->activateRows = this->gatherOutTilingData->activateRows; + this->cols = tilingData->cols; + this->n = tilingData->n; + this->k = tilingData->k; + this->dropPadMode = tilingData->dropPadMode; + + if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) { + this->coreRows = this->gatherOutTilingData->lastCoreRows; + this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->lastCoreLoops; + } else { + this->coreRows = this->gatherOutTilingData->perCoreRows; + this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->perCoreLoops; + } + this->perLoopCols = this->gatherOutTilingData->perLoopCols; + this->lastLoopCols = this->gatherOutTilingData->lastLoopCols; + this->colLoops = this->gatherOutTilingData->colLoops; + + inputXGm.SetGlobalBuffer((__gm__ T *)inputX, this->coreRows * this->cols); + expandedXGm.SetGlobalBuffer((__gm__ T *)expandedX, tilingData->n * tilingData->k * this->cols); + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + + pipe->InitBuffer(inputActivationsCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T))); + pipe->InitBuffer(expandDstToSrcRowCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t))); +} + +template +__aicore__ inline void MoeV2GatherOut::Process() +{ + if (this->blockIdx < this->needCoreNum) { + currentLoopRows = perLoopRows; + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyInIndices(loop); + CopyOut(loop); + } + } +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_GATHER_OUT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_init_routing_fullload.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_init_routing_fullload.h new file mode 100644 index 00000000..bfe0f144 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_init_routing_fullload.h @@ -0,0 +1,388 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_v2_init_routing_fullload.h + * \brief + */ +#ifndef MOE_V2_FULL_LOAD_H +#define MOE_V2_FULL_LOAD_H + +#include "moe_v2_mrgsort.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +template +class MoeV2FullLoad : public MoeV2SortBase { +public: + __aicore__ inline MoeV2FullLoad(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingV2TilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOutIdx(); + __aicore__ inline void CopyOutEmpty(); + __aicore__ inline void CopyOutX(); + __aicore__ inline void ComputeExpertTokenCountOrCumsum(); + +private: + int64_t sortNum_; + const MoeV2GatherOutComputeTilingData *gatherOutTilingData_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t sortNeedCoreNum_; + int64_t coreRows_; + int64_t perCoreRows_; + int64_t k_; + int64_t n_; + int64_t cols_; + int64_t activateRows_; + int64_t expertNum; + int64_t expertCapacity; + + TQue xCopyInQueue_; + TQue expandedRowIdxCopyOutQueue_; + TQue expandedExpertIdxCopyOutQueue_; + TQue expandDstToSrcRowQueue_; + TQue expertTokensCopyOutQueue_; + + GlobalTensor xGm_; + GlobalTensor expertIdxGm_; + + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedExpertIdxGm_; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + int64_t dropPadMode = 0; +}; + +template +__aicore__ inline void MoeV2FullLoad::CopyIn() +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(this->totalLength * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams); +#else + DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams); +#endif + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + sortDataCopyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::SortCompute() +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_NONE, this->totalLength); +#else + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); +#endif + PipeBarrier(); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength); + PipeBarrier(); + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor concatLocal; +#else + LocalTensor concatLocal = expertIdxLocalFp32; +#endif + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum_)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast(); + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum_)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_NONE, + this->totalLength); +#else + Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, + this->totalLength); +#endif + PipeBarrier(); + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength); + PipeBarrier(); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + PipeBarrier(); + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdxLocalInt32); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandedRowIdxU32 = expandedRowIdx.ReinterpretCast(); + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength); + PipeBarrier(); + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + PipeBarrier(); + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + sortDataCopyInQueue.FreeTensor(inLocal); + + expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutIdx() +{ + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyExtParams intriParams; +#else + DataCopyParams intriParams; +#endif + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expandedRowIdxGm_, expandedRowIdx, intriParams.blockCount, intriParams.blockLen); +#else + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams); +#endif + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); +} + +template +__aicore__ inline void MoeV2FullLoad::ComputeExpertTokenCountOrCumsum() +{ + LocalTensor expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue(); + LocalTensor expertTokensCount = expertTokensCopyOutQueue_.AllocTensor(); + + int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t)); + Duplicate(expertTokensCount, 0, expertNumAlign); + SetWaitFlag(HardEvent::V_S); + + int32_t lastExpertId = expandedExpertIdx.GetValue(0); +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致lastExpertId为未初始化的脏数据,因此需要特殊处理 + if (lastExpertId > expertTokensCount.GetSize()) { + return; + } +#endif + int64_t tokenCount = 0; + int64_t lastExpertCount = 0; + for (int64_t i = 0; i < this->totalLength; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + tokenCount++; + while (lastExpertId < curExpertId) { + expertTokensCount.SetValue(lastExpertId, tokenCount - 1); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { + tokenCount = 1; + } + lastExpertId++; + } + } + expertTokensCount.SetValue(lastExpertId, tokenCount); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) { + lastExpertId++; + while (lastExpertId < this->expertNum) { + expertTokensCount.SetValue(lastExpertId, tokenCount); + lastExpertId++; + } + } + DataCopyExtParams copyParams{static_cast(1), static_cast(this->expertNum * sizeof(int32_t)), 0, + 0, 0}; + if (this->expertTokensCountOrCumsumFlag > 0) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expertTokensCountOrCumsumGm, expertTokensCount, copyParams.blockCount, copyParams.blockLen); +#else + DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams); +#endif + } + expertTokensCopyOutQueue_.FreeTensor(expertTokensCount); + expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutX() +{ + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->cols_ * sizeof(T); + int64_t inFactor = Align(this->cols_, sizeof(T)); + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_; + + DataCopyExtParams dataXCopyParams{static_cast(endXRow - startXRow + 1), + static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataXCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(xLocal, xGm_[startXRow * this->cols_], dataXCopyParams, dataXCopyPadParams); +#else + DataCopyPad(xLocal, xGm_[startXRow * this->cols_], dataXCopyParams, dataXCopyPadParams); +#endif + SetWaitFlag(HardEvent::MTE2_S); + + int64_t k = 0; + for (int64_t i = startXRow; i <= endXRow; i++) { + for (; k < this->perCoreRows_ && curRowsStart / this->k_ == i; curRowsStart++, k++) { + int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); + if (outIndex < this->activateRows_) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expandedXGm_[outIndex * this->cols_], xLocal[(i - startXRow) * inFactor], + intriParams.blockCount, intriParams.blockLen); +#else + DataCopyPad(expandedXGm_[outIndex * this->cols_], xLocal[(i - startXRow) * inFactor], intriParams); +#endif + } + } + } + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + xCopyInQueue_.FreeTensor(xLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutEmpty() +{ + LocalTensor outLocal = expandedExpertIdxCopyOutQueue_.DeQue(); + expandedExpertIdxCopyOutQueue_.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingV2TilingData *tilingData, TPipe *tPipe) +{ + this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + //this->blockIdx_ = GetBlockIdx(); + this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num(); + this->n_ = tilingData->n; + this->k_ = tilingData->k; + this->cols_ = tilingData->cols; + this->sortNeedCoreNum_ = this->gatherOutTilingData_->needCoreNum; + this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; + if (needCoreNum_ == 0) { + this->sortNeedCoreNum_ = tilingData->vbsComputeParamsOp.needCoreNum; + } + this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; + this->activateRows_ = this->gatherOutTilingData_->activateRows; + if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) { + this->coreRows_ = this->gatherOutTilingData_->lastCoreRows; + } else { + this->coreRows_ = this->gatherOutTilingData_->perCoreRows; + } + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->dropPadMode = tilingData->dropPadMode; + this->expertNum = tilingData->expertNum; + + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->pipe = tPipe; + + xGm_.SetGlobalBuffer((__gm__ T *)x); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx, this->tileLength); + + expandedXGm_.SetGlobalBuffer((__gm__ T *)expandedX); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, this->tileLength); + if (this->expertTokensCountOrCumsumFlag > 0) { + // dropless + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); + } + + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum_ * sizeof(int32_t); + + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_; + pipe->InitBuffer(xCopyInQueue_, bufferNum, AlignBytes(this->cols_, sizeof(T)) * (endXRow - startXRow + 1)); + pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t))); + pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor); + +// sort310p +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + pipe->InitBuffer(tempBuffer, buffSize * REGIONP_ROPOSAL_KV_RATIO * kvFactor); + pipe->InitBuffer(sortedBuffer, buffSize * REGIONP_ROPOSAL_KV_RATIO * kvFactor); +#else + pipe->InitBuffer(tempBuffer, buffSize * kvFactor); + pipe->InitBuffer(sortedBuffer, buffSize * kvFactor); +#endif +} + +template +__aicore__ inline void MoeV2FullLoad::Process() +{ + if (this->blockIdx_ < this->sortNeedCoreNum_) { + CopyIn(); + SortCompute(); + if (this->blockIdx_ == 0) { + CopyOutIdx(); + } + if (this->blockIdx_ == this->sortNeedCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + ComputeExpertTokenCountOrCumsum(); + } else { + CopyOutEmpty(); + } + } + if (this->needCoreNum_ != 0) { + if (this->blockIdx_ < this->needCoreNum_) { + CopyOutX(); + } + + } else { + if (this->blockIdx_ < this->sortNeedCoreNum_) { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + } + } +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_FULL_LOAD_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort.h new file mode 100644 index 00000000..fef3005e --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort.h @@ -0,0 +1,211 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_mrgsort.h + * \brief + */ +#ifndef MOE_V2_MRGSORT_H +#define MOE_V2_MRGSORT_H + +#include "moe_v2_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +struct MoeV2MrgsortParam { + int64_t perListElements; + int64_t lastListElements; + int64_t oneLoopMaxElements; +}; + +class MoeV2Mrgsort { +public: + __aicore__ inline MoeV2Mrgsort(){}; + __aicore__ inline void Init(MoeV2MrgsortParam *param); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeV2MrgsortParam *param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput; + + LocalTensor ubInputs[4]; + LocalTensor ubOutput; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail{0}; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeV2Mrgsort::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeV2Mrgsort::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput) +{ + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeV2Mrgsort::SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput) +{ + this->gmOutput = gmOutput; + this->ubOutput = ubOutput; +} + +__aicore__ inline void MoeV2Mrgsort::UpdateMrgParam() +{ + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeV2Mrgsort::CopyIn() +{ + this->remainListNum = 0; + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeV2Mrgsort::MrgsortCompute() +{ + SetWaitFlag(HardEvent::MTE2_V); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->ubOutput, this->tmpUbInputs[0], + Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeV2Mrgsort::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeV2Mrgsort::CopyOut() +{ +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyExtParams intriParams; +#else + DataCopyParams intriParams; +#endif + intriParams.blockCount = 1; + intriParams.blockLen = GetSortLen(curLoopSortedNum) * sizeof(float); + SetWaitFlag(HardEvent::V_MTE3); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(this->gmOutput[outOffset], this->ubOutput, intriParams.blockCount, + intriParams.blockLen); +#else + DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams); +#endif + outOffset += GetSortLen(curLoopSortedNum); +} + +__aicore__ inline void MoeV2Mrgsort::Init(MoeV2MrgsortParam *param) +{ + this->param = param; + this->remainListNum = listNum; + + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeV2Mrgsort::Process() +{ + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + CopyOut(); + } + + ClearCache(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_MRGSORT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort_out.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort_out.h new file mode 100644 index 00000000..fd1facdc --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_mrgsort_out.h @@ -0,0 +1,245 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_mrgsort_out.h + * \brief + */ +#ifndef MOE_V2_MRGSORT_OUT_H +#define MOE_V2_MRGSORT_OUT_H + +#include "moe_v2_mrgsort.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2MrgsortOut { +public: + __aicore__ inline MoeV2MrgsortOut(){}; + __aicore__ inline void Init(MoeV2MrgsortParam *param, TPipe *tPipe); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput1, GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2); + __aicore__ inline void SetBuffer(LocalTensor &tempBuffer); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void Extract(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeV2MrgsortParam *param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput1; + GlobalTensor gmOutput2; + + LocalTensor ubInputs[4]; + LocalTensor tempBuffer; + + // for extract + LocalTensor ubOutput1; + LocalTensor ubOutput2; + + // for copy out + LocalTensor ubOutputInt1; + LocalTensor ubOutputInt2; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + bool needCopyOut = true; +#endif +}; + +__aicore__ inline void MoeV2MrgsortOut::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeV2MrgsortOut::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput) +{ + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeV2MrgsortOut::SetOutput(GlobalTensor &gmOutput1, GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2) +{ + this->gmOutput1 = gmOutput1; + this->ubOutput1 = ubOutput1; + this->ubOutputInt1 = ubOutput1.ReinterpretCast(); + + this->gmOutput2 = gmOutput2; + this->ubOutput2 = ubOutput2.ReinterpretCast(); + this->ubOutputInt2 = ubOutput2.ReinterpretCast(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + needCopyOut = !(this->gmOutput1.GetPhyAddr() == this->gmOutput2.GetPhyAddr()); +#endif +} + +__aicore__ inline void MoeV2MrgsortOut::SetBuffer(LocalTensor &tempBuffer) +{ + this->tempBuffer = tempBuffer; +} + +__aicore__ inline void MoeV2MrgsortOut::UpdateMrgParam() +{ + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeV2MrgsortOut::CopyIn() +{ + this->remainListNum = 0; + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeV2MrgsortOut::MrgsortCompute() +{ + SetWaitFlag(HardEvent::MTE2_V); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->tempBuffer, this->tmpUbInputs[0], + Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeV2MrgsortOut::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeV2MrgsortOut::Extract() +{ + AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM)); + PipeBarrier(); + Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float))); + PipeBarrier(); + Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float))); +} + +__aicore__ inline void MoeV2MrgsortOut::CopyOut() +{ +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyExtParams intriParams; +#else + DataCopyParams intriParams; +#endif + intriParams.blockCount = 1; + intriParams.blockLen = curLoopSortedNum * sizeof(int32_t); + SetWaitFlag(HardEvent::V_MTE3); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + if (needCopyOut) { + DataCopyCustom(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams.blockCount, intriParams.blockLen); + } + DataCopyCustom(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams.blockCount, intriParams.blockLen); +#else + DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams); + DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams); +#endif + outOffset += curLoopSortedNum; +} + +__aicore__ inline void MoeV2MrgsortOut::Init(MoeV2MrgsortParam *param, TPipe *tPipe) +{ + this->param = param; + this->allRemainElements = 0; + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeV2MrgsortOut::Process() +{ + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + Extract(); + CopyOut(); + } + ClearCache(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_MRGSORT_OUT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_base.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_base.h new file mode 100644 index 00000000..b20d6d3f --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_base.h @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_sort_base.h + * \brief + */ +#ifndef MOE_V2_SORT_BASE_H +#define MOE_V2_SORT_BASE_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2SortBase { +public: + __aicore__ inline MoeV2SortBase(){}; + +protected: + __aicore__ inline void SyncAll(); + +protected: + TPipe *pipe; + TQue sortDataCopyInQueue; + TQue sortDataCopyOutQueue; + TBuf tempBuffer; + TBuf sortedBuffer; + + GlobalTensor expertIdxGm; + GlobalTensor sortedexpertIdxGm; + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + GlobalTensor syncTmpSpaceGm_; +#endif + int64_t tileLength; + int64_t bufferNum = 1; + int64_t totalLength; + int64_t coreNum; + int64_t n; + int64_t k; + int64_t existRowIdx; + int64_t expertNum; + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + + static constexpr int64_t SYNC_GM_NUM = 2; + static constexpr int64_t WORK_GM_NUM = 2; + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; +}; + +__aicore__ inline void MoeV2SortBase::SyncAll() +{ +#ifndef __CCE_KT_TEST__ + if (coreNum == 1) { + return; + } + AscendC::SyncAll(); +#endif +} + +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_SORT_BASE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_multi_core.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_multi_core.h new file mode 100644 index 00000000..eebe6284 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_multi_core.h @@ -0,0 +1,507 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_sort_multi_core.h + * \brief + */ +#ifndef MOE_V2_VBS_ONE_CORE_H +#define MOE_V2_VBS_ONE_CORE_H + +#include "moe_v2_sort_base.h" +#include "moe_v2_mrgsort.h" +#include "moe_v2_mrgsort_out.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2SortMultiCore : public MoeV2SortBase { +public: + __aicore__ inline MoeV2SortMultiCore(){}; + template + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, const TilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + __aicore__ inline void ResetIO(GM_ADDR expandedRowIdx, GM_ADDR workspace); +#endif +private: + __aicore__ inline void VBSProcess(); + __aicore__ inline void UBSortProcess(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void OneCoreVMSProcess(int64_t listNum, int64_t perListElements, int64_t lastListElements); + __aicore__ inline void VMSProcess(); + __aicore__ inline void SortOutProcess(); + __aicore__ inline void VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void UBSortCompute(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void InitMoeMrgSort(MoeV2Mrgsort *sorter, int64_t listNum, int64_t coreOffset, + int64_t loopOffset); + __aicore__ inline void InitMoeMrgSortOut(MoeV2MrgsortOut *sorter, int64_t listNum, int64_t coreOffset); + __aicore__ inline void InitExpertTokensGlobalMemory(); + +private: + GlobalTensor workspaceGms[2]; + + const MoeV2VBSComputeTilingData *vbsTilingData; + const MoeV2VMSMiddleComputeTilingData *vmsTilingData; + const MoeV2SortOutComputeTilingData *sortOutTilingData; + + // for MoeMrgsort + MoeV2Mrgsort mrgsorter; + MoeV2MrgsortParam mrgsortParam; + + int64_t blockIdx; + int64_t srcWsIndex = 0; + + int64_t listNum; + int64_t perListElements; + int64_t lastListElements; + + int64_t sortTotalLength; + int64_t sortCoreLoops; + int64_t sortCoreLoopElements; + int64_t sortCoreLastLoopElements; + + int64_t perCoreExpert; + int64_t needInitExpertCore; + int64_t currentCoreExpert; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + int64_t perCoreOffset; +#endif + static constexpr int64_t MAX_MRGSORT_LIST = 4; +}; + +__aicore__ inline void MoeV2SortMultiCore::InitExpertTokensGlobalMemory() +{ + if (this->blockIdx < this->needInitExpertCore) { + if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + InitGlobalMemory(expertTokensCountOrCumsumGm, Align(this->currentCoreExpert, sizeof(int32_t)), 0); +#else + InitGlobalMemory(expertTokensCountOrCumsumGm, currentCoreExpert, 0); +#endif + } + if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + InitGlobalMemory(expertTokensBeforeCapacityGm, Align(this->currentCoreExpert, sizeof(int32_t)), 0); +#else + InitGlobalMemory(expertTokensBeforeCapacityGm, currentCoreExpert, 0); +#endif + } + } +} + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 +__aicore__ inline void MoeV2SortMultiCore::ResetIO(GM_ADDR expandedRowIdx, GM_ADDR workspace) +{ + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(expandedRowIdx), + Align(this->totalLength, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(expandedRowIdx), + Align(this->totalLength, sizeof(int32_t))); + expertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace) + + Align(this->totalLength, sizeof(int32_t)) + this->blockIdx * perCoreOffset, + this->sortTotalLength); + this->srcWsIndex = 0; + this->needInitExpertCore = 0; +} +#endif + +__aicore__ inline void MoeV2SortMultiCore::VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum) +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + int64_t inOffset = progress * sortCoreLoopElements; + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(size * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams); +#else + DataCopyPad(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams); +#endif + LocalTensor rowIdxLocal = inLocal[sortNum]; + int64_t startValue = this->blockIdx * this->vbsTilingData->perCoreElements + inOffset; + SetWaitFlag(HardEvent::MTE3_S); + ArithProgression(rowIdxLocal, startValue, 1, size); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::UBSortCompute(int64_t progress, int64_t size, int64_t sortNum) +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertForSourceRowLocal = inLocal[0]; + LocalTensor expertForSourceRowLocalFp32; + + expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast(); + #if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_NONE, sortNum); + #else + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, sortNum); + #endif + PipeBarrier(); + Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, sortNum); + PipeBarrier(); + + int64_t duplicateNum = size % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = size - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(sortNum)); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[sortNum].ReinterpretCast(); + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor concatLocal; + SetWaitFlag(HardEvent::MTE3_V); + + LocalTensor outLocal = tempBuffer.Get(GetSortLen(sortNum)); + + Concat(concatLocal, expertForSourceRowLocalFp32, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM); + + PipeBarrier(); + + Sort(outLocal, concatLocal, sourceRowLocal, concatLocal, sortNum / ONE_REPEAT_SORT_NUM); + + SetWaitFlag(HardEvent::V_MTE3); +#else + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + + LocalTensor concatLocal = expertForSourceRowLocalFp32; + + Sort(outLocal, concatLocal, sourceRowLocal, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM); + + sortDataCopyOutQueue.EnQue(outLocal); +#endif + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum) +{ +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor outLocal = tempBuffer.Get(GetSortLen(sortNum)); +#else + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); +#endif + DataCopy(workspaceGms[0][this->blockIdx * GetSortLen(this->vbsTilingData->perCoreElements) + + GetSortLen(progress * sortCoreLoopElements)], + outLocal, Align(GetSortLen(size), sizeof(float))); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + SetWaitFlag(HardEvent::MTE3_V); +#else + sortDataCopyOutQueue.FreeTensor(outLocal); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSort(MoeV2Mrgsort *sorter, int64_t listNum, int64_t coreOffset, + int64_t loopOffset) +{ + GlobalTensor srcWsGm = workspaceGms[srcWsIndex][blockIdx * coreOffset + loopOffset]; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor inLocal = sortedBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); + LocalTensor outLocal = tempBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); +#else + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); +#endif + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + GlobalTensor dstWsGm = workspaceGms[1 - srcWsIndex][blockIdx * coreOffset + loopOffset]; + sorter->SetOutput(dstWsGm, outLocal); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ != 200 + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSortOut(MoeV2MrgsortOut *sorter, int64_t listNum, + int64_t coreOffset) +{ + GlobalTensor srcWsGm = workspaceGms[srcWsIndex]; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor inLocal = sortedBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); +#else + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); +#endif + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + + LocalTensor outLocalV = outLocal[this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST]; + sorter->SetOutput(this->sortedexpertIdxGm, this->expandDstToSrcRowGm, outLocal, outLocalV); + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor tempLocal = tempBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); + // buffer for Extract + sorter->SetBuffer(tempLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +#else + LocalTensor tempLocal = + sortedBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); + // buffer for Extract + sorter->SetBuffer(tempLocal); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::OneCoreVMSProcess(int64_t listNum, int64_t perListElements, + int64_t lastListElements) +{ + int64_t coreOffset = GetSortLen(this->vbsTilingData->perCoreElements); + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + + for (int64_t i = 0; listNum >= 1; i++) { + int64_t loops = (listNum + MAX_MRGSORT_LIST - 1) / MAX_MRGSORT_LIST; + int64_t remainListNum = listNum - (loops - 1) * MAX_MRGSORT_LIST; + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + + int64_t loopOffset = GetSortLen(mrgsortParam.perListElements * MAX_MRGSORT_LIST); + for (int64_t loop = 0; loop < loops - 1; loop++) { + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, loop * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, (loops - 1) * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + + listNum = loops; + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + perListElements = perListElements * MAX_MRGSORT_LIST; + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + + if (loops == 1) { + break; + } + } +} + +__aicore__ inline void MoeV2SortMultiCore::UBSortProcess(int64_t progress, int64_t size, int64_t sortNum) +{ + VBSCopyIn(progress, size, sortNum); + UBSortCompute(progress, size, sortNum); + VBSCopyOut(progress, size, sortNum); +} + +__aicore__ inline void MoeV2SortMultiCore::VBSProcess() +{ + if (this->blockIdx < this->vbsTilingData->needCoreNum) { + int64_t sortNum = Ceil(sortCoreLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + for (int64_t loop = 0; loop < sortCoreLoops - 1; loop++) { + UBSortProcess(loop, sortCoreLoopElements, sortNum); + } + + sortNum = Ceil(sortCoreLastLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + UBSortProcess(sortCoreLoops - 1, sortCoreLastLoopElements, sortNum); + + if (sortCoreLoops > 1) { + OneCoreVMSProcess(sortCoreLoops, sortCoreLoopElements, sortCoreLastLoopElements); + } + } + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor syncLocal = tempBuffer.Get(); + AscendC::SyncAll(syncTmpSpaceGm_, syncLocal, GetBlockNum()); +#else + SyncAll(); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::VMSProcess() +{ + int64_t currentStageNeedCoreNum = this->vmsTilingData->needCoreNum; + perListElements = this->vbsTilingData->perCoreElements; + lastListElements = this->vbsTilingData->lastCoreElements; + listNum = this->vbsTilingData->needCoreNum; + + for (; listNum > MAX_MRGSORT_LIST;) { + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + int64_t coreOffset = GetSortLen(perListElements * MAX_MRGSORT_LIST); + int64_t remainListNum = listNum - (currentStageNeedCoreNum - 1) * MAX_MRGSORT_LIST; + + if (this->blockIdx < currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } else if (this->blockIdx == currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + listNum = currentStageNeedCoreNum; + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + perListElements = perListElements * MAX_MRGSORT_LIST; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor syncLocal = tempBuffer.Get(); + AscendC::SyncAll(syncTmpSpaceGm_, syncLocal, GetBlockNum()); +#else + SyncAll(); +#endif + } +} + +__aicore__ inline void MoeV2SortMultiCore::SortOutProcess() +{ + if (this->blockIdx < 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + + MoeV2MrgsortOut sorter; + InitMoeMrgSortOut(&sorter, listNum, GetSortLen(perListElements)); + sorter.Init(&mrgsortParam, pipe); + sorter.Process(); + } +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor syncLocal = tempBuffer.Get(); + AscendC::SyncAll(syncTmpSpaceGm_, syncLocal, GetBlockNum()); +#else + SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, + const TilingData *tilingData, TPipe *tPipe) +{ + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->vbsTilingData = &(tilingData->vbsComputeParamsOp); + this->vmsTilingData = &(tilingData->vmsMiddleComputeParamsOp); + this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + this->perCoreOffset = this->vbsTilingData->perCoreElements; +#endif + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->tileLength = this->vbsTilingData->perCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->perCoreElements; + if (this->blockIdx == tilingData->vbsComputeParamsOp.needCoreNum - 1) { + this->tileLength = this->vbsTilingData->lastCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->lastCoreElements; + } + this->n = tilingData->n; + this->k = tilingData->k; + this->expertNum = tilingData->expertNum; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + + // VBS param init + if (this->blockIdx == this->vbsTilingData->needCoreNum - 1) { + sortCoreLoops = this->vbsTilingData->lastCoreLoops; + sortCoreLoopElements = this->vbsTilingData->lastCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->lastCoreLastLoopElements; + } else { + sortCoreLoops = this->vbsTilingData->perCoreLoops; + sortCoreLoopElements = this->vbsTilingData->perCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->perCoreLastLoopElements; + } + + this->pipe = tPipe; + expertIdxGm.SetGlobalBuffer((__gm__ int32_t *)expertIdx + + this->blockIdx * tilingData->vbsComputeParamsOp.perCoreElements, + this->sortTotalLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace), + Align(this->totalLength, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace) + + Align(this->totalLength, sizeof(int32_t)), + Align(this->totalLength, sizeof(int32_t))); + + this->perCoreExpert = Align((this->expertNum + this->coreNum - 1) / this->coreNum, sizeof(int32_t)); + this->needInitExpertCore = (this->expertNum + this->perCoreExpert - 1) / this->perCoreExpert; + this->currentCoreExpert = this->perCoreExpert; + if (this->blockIdx == needInitExpertCore - 1) { + this->currentCoreExpert = this->expertNum - (this->needInitExpertCore - 1) * this->perCoreExpert; + } + if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensCountOrCumsum + + this->blockIdx * this->perCoreExpert, + this->currentCoreExpert); + } + if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { + expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensBeforeCapacity + + this->blockIdx * this->perCoreExpert, + this->currentCoreExpert); + } + // key and value + int64_t kvFactor = 2; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + int64_t workspaceLen = GetSortLen(Align(this->totalLength, sizeof(int32_t))); + workspaceGms[0].SetGlobalBuffer((__gm__ float *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + workspaceLen); + workspaceGms[1].SetGlobalBuffer((__gm__ float *)workspace + + Align(this->totalLength, sizeof(int32_t)) * 2 + workspaceLen, + workspaceLen); + +#else + workspaceGms[0].SetGlobalBuffer((__gm__ float *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + Align(this->totalLength, sizeof(int32_t)) * kvFactor); + workspaceGms[1].SetGlobalBuffer((__gm__ float *)workspace + + Align(this->totalLength, sizeof(int32_t)) * (kvFactor + 2), + Align(this->totalLength, sizeof(int32_t)) * kvFactor); +#endif + + int64_t bufferSize = Ceil(Max(this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST, sortCoreLoopElements), + ONE_REPEAT_SORT_NUM) * + ONE_REPEAT_SORT_NUM * sizeof(int32_t) * kvFactor; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + syncTmpSpaceGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(this->totalLength, sizeof(int32_t)) * 2 + 2 * workspaceLen, + INT32_ONE_BLOCK_NUM * GetBlockNum() * BLOCK_BYTES); + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize); + pipe->InitBuffer(tempBuffer, bufferSize * REGIONP_ROPOSAL_KV_RATIO); + pipe->InitBuffer(sortedBuffer, bufferSize* REGIONP_ROPOSAL_KV_RATIO); + LocalTensor syncLocal = tempBuffer.Get(); + Duplicate(syncLocal, 0, SYNC_LEN); + SetWaitFlag(HardEvent::V_MTE3); + DataCopy(syncTmpSpaceGm_, syncLocal, SYNC_LEN); +#else + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortedBuffer, bufferSize); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::Process() +{ + InitExpertTokensGlobalMemory(); + VBSProcess(); + VMSProcess(); + SortOutProcess(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_VBS_ONE_CORE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_one_core.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_one_core.h new file mode 100644 index 00000000..0f21e1e3 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_sort_one_core.h @@ -0,0 +1,226 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_sort_one_core.h + * \brief + */ +#ifndef MOE_V2_SORT_ONE_CORE_H +#define MOE_V2_SORT_ONE_CORE_H + +#include "moe_v2_mrgsort.h" +#include "moe_v2_sort_base.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2SortOneCore : public MoeV2SortBase { +public: + __aicore__ inline MoeV2SortOneCore(){}; + template + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, const TilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + __aicore__ inline void ResetIO(GM_ADDR expandedRowIdx, GM_ADDR workspace); + bool needCopy = true; +#endif + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOut(); + +private: + int64_t sortNum; + int64_t blockIdx; + int64_t needCoreNum; +}; + +__aicore__ inline void MoeV2SortOneCore::CopyIn() +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(this->totalLength * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyPadCustom(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); +#else + DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); +#endif + LocalTensor rowIdxLocal = inLocal[this->sortNum]; + ArithProgression(rowIdxLocal, 0, 1, this->sortNum); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SortOneCore::SortCompute() +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertForSourceRowLocal = inLocal[0]; + LocalTensor expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast(); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_NONE, this->tileLength); +#else + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); +#endif + PipeBarrier(); + Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, this->tileLength); + PipeBarrier(); + + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + + LocalTensor concatLocal = expertForSourceRowLocalFp32; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum)); + Concat(concatLocal, expertForSourceRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum)); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[this->sortNum].ReinterpretCast(); + Sort(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sortedExpertForSourceRowLocal = outLocal[0]; + LocalTensor expandDstToSrcRowLocal; + expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast(); + Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength); + PipeBarrier(); + + LocalTensor expertForSourceRowLocalInt32; + expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2SortOneCore::CopyOut() +{ + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + DataCopyCustom(expandDstToSrcRowGm, outLocal[this->sortNum], intriParams.blockCount, intriParams.blockLen); + if (this->needCopy) { + DataCopyCustom(sortedexpertIdxGm, outLocal[0], intriParams.blockCount, intriParams.blockLen); + } +#else + DataCopyPad(sortedexpertIdxGm, outLocal[0], intriParams); + DataCopyPad(expandDstToSrcRowGm, outLocal[this->sortNum], intriParams); +#endif + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 +__aicore__ inline void MoeV2SortOneCore::ResetIO(GM_ADDR expandedRowIdx, GM_ADDR workspace) +{ + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(expandedRowIdx), this->tileLength); + expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(expandedRowIdx), this->tileLength); + expertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace) + this->tileLength, this->tileLength); + this->expertTokensCountOrCumsumFlag = 0; + this->expertTokensBeforeCapacityFlag = 0; + this->needCopy = false; +} +#endif + +template +__aicore__ inline void MoeV2SortOneCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, + const TilingData *tilingData, TPipe *tPipe) +{ + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->pipe = tPipe; + this->n = tilingData->n; + this->k = tilingData->k; + this->expertNum = tilingData->expertNum; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + this->needCoreNum = tilingData->vbsComputeParamsOp.needCoreNum; + + expertIdxGm.SetGlobalBuffer((__gm__ int32_t *)expertIdx, this->tileLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace), this->tileLength); + expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace) + this->tileLength, + this->tileLength); + + if (this->blockIdx == this->coreNum - 1) { + if (this->expertTokensCountOrCumsumFlag > 0) { + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + InitGlobalMemory(expertTokensCountOrCumsumGm, Align(this->expertNum, sizeof(int32_t)), 0); +#else + InitGlobalMemory(expertTokensCountOrCumsumGm, this->expertNum, 0); +#endif + } + if (this->expertTokensBeforeCapacityFlag == 1) { + expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t *)expertTokensBeforeCapacity, + Align(this->expertNum, sizeof(int32_t))); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + InitGlobalMemory(expertTokensBeforeCapacityGm, Align(this->expertNum, sizeof(int32_t)), 0); +#else + InitGlobalMemory(expertTokensBeforeCapacityGm, this->expertNum, 0); +#endif + } + } + // key and value + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize); +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + syncTmpSpaceGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + 2 * this->tileLength, SYNC_LEN); + buffSize = GetSortLen(this->sortNum) * sizeof(int32_t); + pipe->InitBuffer(tempBuffer, buffSize); + pipe->InitBuffer(sortedBuffer, buffSize); + LocalTensor syncLocal = tempBuffer.Get(); + Duplicate(syncLocal, 0, SYNC_LEN); + SetWaitFlag(HardEvent::V_MTE3); + DataCopy(syncTmpSpaceGm_, syncLocal, SYNC_LEN); + PipeBarrier(); +#else + pipe->InitBuffer(tempBuffer, buffSize); + pipe->InitBuffer(sortedBuffer, buffSize); +#endif +} + +__aicore__ inline void MoeV2SortOneCore::Process() +{ + if (GetBlockIdx() < this->needCoreNum) { + CopyIn(); + SortCompute(); + CopyOut(); + } +#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200 + LocalTensor syncLocal = tempBuffer.Get(); + AscendC::SyncAll(syncTmpSpaceGm_, syncLocal, GetBlockNum()); +#else + this->SyncAll(); +#endif +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_SORT_ONE_CORE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op.h new file mode 100644 index 00000000..a0e214d6 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op.h @@ -0,0 +1,173 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_src_to_dst_op.h + * \brief + */ +#ifndef MOE_V2_SRC_TO_DST_H +#define MOE_V2_SRC_TO_DST_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2SrcToDstOp { +public: + __aicore__ inline MoeV2SrcToDstOp(){}; + template + __aicore__ inline void Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void Compute(int64_t progress); + __aicore__ inline void CopyOut(); + __aicore__ inline void SyncAll(); + __aicore__ inline void AssistInit(); + +private: + TPipe *pipe; + TQue copyInQueue; + TQue copyOutQueue; + TBuf assistBuffer; + + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expandSrcToDstRowGm; + GlobalTensor assistGm; + + const MoeV2GatherOutComputeTilingData *srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; +}; + +__aicore__ inline void MoeV2SrcToDstOp::AssistInit() +{ +#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1 + OOMCheckAddrRange(assistGm.GetPhyAddr(), ASSIST_NUM * sizeof(int32_t)); +#endif + LocalTensor assistTensor = assistBuffer.Get(ASSIST_NUM); + DataCopy(assistTensor, assistGm, ASSIST_NUM); + SetWaitFlag(HardEvent::MTE2_V); + Adds(assistTensor, assistTensor, (int32_t)(this->blockIdx * this->srcToDstTilingData->perCoreRows), ASSIST_NUM); +} + +__aicore__ inline void MoeV2SrcToDstOp::CopyIn(int64_t progress) +{ + LocalTensor inLocal = copyInQueue.AllocTensor(); + DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t))); + copyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::Compute(int64_t progress) +{ + LocalTensor outLocal = copyOutQueue.AllocTensor(); + LocalTensor assistTensor = assistBuffer.Get(ASSIST_NUM); + + PipeBarrier(); + int64_t loops = Ceil(currentLoopRows, ASSIST_INDEX_NUM); + for (int64_t i = 0; i < loops; i++) { + Adds(outLocal[i * ASSIST_NUM], assistTensor, + static_cast(this->perLoopRows * progress + i * ASSIST_INDEX_NUM), ASSIST_NUM); + } + PipeBarrier(); + copyOutQueue.EnQue(outLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::CopyOut() +{ + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = sizeof(int32_t); + uint32_t outOffset; + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + outOffset = inLocal.GetValue(idx); + DataCopyPad(expandSrcToDstRowGm[outOffset], outLocal[idx * INT32_ONE_BLOCK_NUM], intriParams); + } + + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::SyncAll() +{ + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SrcToDstOp::Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData *tilingData, + TPipe *tPipe) +{ + int64_t blockNum = GetBlockNum() * 2; + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp); + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + } + + expandSrcToDstRowGm.SetGlobalBuffer((__gm__ int32_t *)expandSrcToDstRow, Align(this->totalLength, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)) + + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + assistGm.SetGlobalBuffer((__gm__ int32_t *)assist, ASSIST_NUM); + + pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES); + pipe->InitBuffer(copyOutQueue, 1, Ceil(this->perLoopRows, ASSIST_NUM) * ASSIST_NUM * BLOCK_BYTES); + pipe->InitBuffer(assistBuffer, ASSIST_NUM * sizeof(int32_t)); +} + +__aicore__ inline void MoeV2SrcToDstOp::Process() +{ + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + currentLoopRows = perLoopRows; + AssistInit(); + for (int64_t loop = 0; loop < loops - 1; loop++) { + CopyIn(loop); + Compute(loop); + CopyOut(); + } + currentLoopRows = lastLoopRows; + CopyIn(loops - 1); + Compute(loops - 1); + CopyOut(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_SRC_TO_DST_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op_simt.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op_simt.h new file mode 100644 index 00000000..364c8fd0 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_op_simt.h @@ -0,0 +1,96 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_src_to_dst_op_simt.h + * \brief + */ +#ifndef MOE_V2_SRC_TO_DST_SIMT_H +#define MOE_V2_SRC_TO_DST_SIMT_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +class MoeV2SrcToDstOpSimt { +public: + __aicore__ inline MoeV2SrcToDstOpSimt(){}; + template + __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandDstToSrcRow, const TilingData *tilingData); + __aicore__ inline void Process(); + +private: + __aicore__ inline void SyncAll(); + __aicore__ inline void ComputeSimt() const; + +private: + __gm__ int32_t *expandDstToSrcRowGm_; + __gm__ int32_t *expandedRowIdxGm_; + const MoeV2GatherOutComputeTilingData *srcToDstTilingData_; + + int64_t coreNum_; + int64_t blockIdx_; + int64_t totalLength_; + int64_t perCoreRows_; + int64_t coreRows_; + int64_t startIndex_; + int64_t threadNum_; +}; + +__aicore__ inline void MoeV2SrcToDstOpSimt::SyncAll() +{ + if (coreNum_ == 1) { + return; + } + AscendC::SyncAll(); +} + +template +__aicore__ inline void MoeV2SrcToDstOpSimt::Init(GM_ADDR expandedRowIdx, GM_ADDR expandDstToSrcRow, + const TilingData *tilingData) +{ + this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num(); + this->coreNum_ = tilingData->coreNum; + this->totalLength_ = tilingData->n * tilingData->k; + this->srcToDstTilingData_ = &(tilingData->srcToDstComputeParamsOp); + this->perCoreRows_ = this->srcToDstTilingData_->perCoreRows; + if (this->blockIdx_ == this->srcToDstTilingData_->needCoreNum - 1) { + this->coreRows_ = this->srcToDstTilingData_->lastCoreRows; + } else { + this->coreRows_ = this->srcToDstTilingData_->perCoreRows; + } + startIndex_ = this->blockIdx_ * this->perCoreRows_; + this->threadNum_ = THREAD_NUM < this->coreRows_ ? THREAD_NUM : this->coreRows_; + + expandedRowIdxGm_ = (__gm__ int32_t *)expandedRowIdx; + expandDstToSrcRowGm_ = (__gm__ int32_t *)expandDstToSrcRow + Align(this->totalLength_, sizeof(int32_t)); +} + +__aicore__ inline void MoeV2SrcToDstOpSimt::ComputeSimt() const +{ + for (int32_t index = static_cast(Simt::GetThreadIdx()); index < static_cast(this->coreRows_); + index += static_cast(Simt::GetThreadNum())) { + int64_t srcIndex = index + this->startIndex_; + int64_t dstIndex = expandDstToSrcRowGm_[srcIndex]; + expandedRowIdxGm_[dstIndex] = srcIndex; + } +} + +__aicore__ inline void MoeV2SrcToDstOpSimt::Process() +{ + if (this->blockIdx_ < this->srcToDstTilingData_->needCoreNum) { + ParallelEXE(this->threadNum_, ComputeSimt); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_SRC_TO_DST_SIMT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_with_capacity.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_with_capacity.h new file mode 100644 index 00000000..aca7328b --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/moe_v2_src_to_dst_with_capacity.h @@ -0,0 +1,279 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_src_to_dst_with_capacity.h + * \brief + */ +#ifndef MOE_V2_SRC_TO_DST_WITH_CAPACITY_H +#define MOE_V2_SRC_TO_DST_WITH_CAPACITY_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingV2 { +using namespace AscendC; +using namespace optiling; + +template +class MoeV2SrcToDstWithCapacity { +public: + __aicore__ inline MoeV2SrcToDstWithCapacity(){}; + __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const TilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + __aicore__ inline void CopyOutRemain(); + __aicore__ inline void SyncAll(); + __aicore__ inline void AssistInit(); + +private: + TPipe *pipe; + TQue copyInQueue; + TQue copyOutQueue; + TQue copyOutZeroQueue; + + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor expertIdxValueGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expandedXGm; + + LocalTensor outTmpLocal; + + const MoeV2GatherOutComputeTilingData *srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t expertCapacity; + int64_t expertNum; + int64_t cols; + int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + + int64_t tokenCount = 0; + int32_t lastExpertId = -1; + int32_t lastCoreExpertId = 0; + int32_t lastCoreExpertIdNum = 0; +}; + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::AssistInit() +{ + if constexpr (IsSameType::value) { + LocalTensor outLocal = copyOutZeroQueue.AllocTensor(); + Duplicate(outLocal, static_cast(0), this->perLoopCols); + copyOutZeroQueue.EnQue(outLocal); + } else { + LocalTensor outLocal = copyOutZeroQueue.AllocTensor(); + Duplicate(outLocal, static_cast(0), this->perLoopCols); + copyOutZeroQueue.EnQue(outLocal); + } + + if (this->blockIdx != 0) { + this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2); + this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1); + for (int64_t i = this->blockIdx - 2; i >= 0; i--) { + int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2); + if (lastExpertIdx < this->lastCoreExpertId) { + break; + } + int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1); + this->lastCoreExpertIdNum += lastExpertNum; + } + } +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyIn(int64_t progress) +{ + LocalTensor inLocal = copyInQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length); + DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length); + copyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyOut(int64_t progress) +{ + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0}; + + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = this->lastCoreExpertId; + this->tokenCount = this->lastCoreExpertIdNum; + } + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + int32_t expertIdx = inLocal[length].GetValue(idx); + SetWaitFlag(HardEvent::S_MTE3); + int32_t index = 0; + while (this->lastExpertId < expertIdx) { + while (this->tokenCount < this->expertCapacity) { + index = this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) { + continue; + } +#endif + DataCopyExtParams copyParams1{static_cast(1), static_cast(col * sizeof(T)), 0, + 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, + copyParams1); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + + if (this->tokenCount < this->expertCapacity) { + int32_t outOffset = inLocal.GetValue(idx); + index = expertIdx * this->expertCapacity + this->tokenCount; + outLocal.SetValue(0, index); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams); + SetWaitFlag(HardEvent::MTE3_S); + this->tokenCount++; + } + } + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyOutRemain() +{ + if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) { + copyOutZeroQueue.FreeTensor(this->outTmpLocal); + return; + } + while (this->lastExpertId < this->expertNum) { + while (this->tokenCount < this->expertCapacity) { + int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } + DataCopyExtParams copyParams{static_cast(1), static_cast(col * sizeof(T)), 0, 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + copyOutZeroQueue.FreeTensor(this->outTmpLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::SyncAll() +{ + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR workspace, const TilingData *tilingData, + TPipe *tPipe) +{ + int64_t blockNum = GetBlockNum() * 2; + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp); + this->expertNum = tilingData->expertNum; + this->expertCapacity = tilingData->expertCapacity; + this->cols = tilingData->cols; + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->lastCoreLoops; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->perCoreLoops; + } + this->perLoopCols = this->srcToDstTilingData->perLoopCols; + this->lastLoopCols = this->srcToDstTilingData->lastLoopCols; + this->colLoops = this->srcToDstTilingData->colLoops; + + int64_t length = Align(this->totalLength, sizeof(int32_t)); + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, length); + expandedXGm.SetGlobalBuffer((__gm__ T *)expandedX, this->expertNum * this->expertCapacity * this->cols); + + expandedExpertIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace + + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t *)workspace + length + + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t *)workspace + length * 2, this->coreNum * 2); + + pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2); + pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t))); + if constexpr (IsSameType::value) { + pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t))); + } else { + pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(T))); + } +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::Process() +{ + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + AssistInit(); + this->outTmpLocal = copyOutZeroQueue.DeQue(); + currentLoopRows = perLoopRows; + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyIn(loop); + CopyOut(loop); + } + CopyOutRemain(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingV2 +#endif // MOE_V2_SRC_TO_DST_WITH_CAPACITY_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/tiling_base.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/tiling_base.h new file mode 100644 index 00000000..d705a1bc --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/moe_init_routing_v2/tiling_base.h @@ -0,0 +1,66 @@ +#pragma once +namespace optiling { +struct AiCoreParams { + uint64_t ubSize; + uint64_t blockDim; + uint64_t aicNum; + + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; +}; + +class TilingBaseClass { +public: + bool DoTiling( + int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0, + int64_t aivCoreNum, int64_t ubSizePlatForm) + { + bool ret = GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode, expertTokensCountOrCumsumFlag, + expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0); + + if (!ret){ + return ret; + } + ret = GetPlatformInfo(aivCoreNum, ubSizePlatForm); + if (!ret){ + return ret; + } + ret = DoOpTiling(); + if (!ret){ + return ret; + } + ret = GetWorkspaceSize(); + if (!ret){ + return ret; + } + // ret = PostTiling(); + // if (!ret){ + // return ret; + // } + tilingKey_ = GetTilingKey(); + + return true; + } + +//protected: + virtual bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) = 0; + virtual bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) = 0; + + virtual bool DoOpTiling() = 0; + virtual bool GetWorkspaceSize() = 0; + // virtual bool PostTiling() = 0; + virtual uint64_t GetTilingKey() const = 0; +//protected: + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0}; +}; + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h new file mode 100644 index 00000000..1255b5cf --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute.h @@ -0,0 +1,376 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file moe_token_unpermute.h + * \brief + */ + +#ifndef MOE_TOKEN_UNPERMUTE +#define MOE_TOKEN_UNPERMUTE + +#include "kernel_operator.h" +#include "moe_token_unpermute_tiling.h" +using namespace AscendC; + + +template class KernelMoeTokenUnpermute { +public: + __aicore__ inline KernelMoeTokenUnpermute() + { + } + + __aicore__ inline void Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs, + GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data); + __aicore__ inline void Process(); + +protected: + __aicore__ inline void CalMultiOutToken(const int64_t out_offset, const int64_t out_tokens_number); + __aicore__ inline void CalSingleOutToken(const int64_t start_token, const int64_t out_token_idx); + __aicore__ inline void CalPartOutToken(const int64_t start_token, const int64_t h_index, const int64_t h_length, + const int64_t out_token_index); + __aicore__ inline void CopyTokenIn(const T2 in_token_index, const int64_t h_index, const int64_t h_length); + __aicore__ inline void CalFirstToken(const float prob_value, const int64_t h_length); + __aicore__ inline void CalToken(const float prob_value, const int64_t h_length); + __aicore__ inline void CopyOut(const int64_t out_token_index, const int64_t h_index, const int64_t h_length); + + TPipe pipe; + TQue tokens_inque, indices_inque, probs_inque; + TBuf temp_buffer0, temp_buffer1, temp_buffer2; + TQue outque; + GlobalTensor tokensGM, outGM; + GlobalTensor indicesGM; + GlobalTensor probsGM; + LocalTensor indicesLocal; + LocalTensor token_tensor0, token_tensor1, probs_tensor; + DataCopyPadExtParams extParams1{false, 0, 0, 0}; + DataCopyPadExtParams extParams2{false, 0, 0, 0}; + DataCopyPadExtParams extParams3{false, 0, 0, 0}; + DataCopyExtParams copyParams{1, 0, 0, 0, 0}; + + constexpr static uint32_t BLOCK_SIZE = 32; + constexpr static uint32_t ALIGN_512 = 512; + + int64_t hidden_size; + int64_t top_k; + int64_t num_out_tokens; + int64_t hidden_splited_length; + int64_t hidden_splited_num; + int64_t hidden_splited_remain; + int64_t tokens_core_length; + int64_t tokens_core_remain; + int64_t tokens_splited_length; + int64_t tokens_splited_num; + int64_t tokens_splited_remain; + int32_t blockIdx; + int32_t blockNum; +}; + +template +__aicore__ inline void +KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs, + GM_ADDR unpermuted_tokens, + const MoeTokenUnpermuteTilingData *__restrict tiling_data) +{ + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->blockNum = get_block_num() * get_subblockdim(); + + if (blockIdx >= blockNum) { + return; + } + ASSERT(blockNum != 0 && "block dim can not be zero!"); + // row_input + this->hidden_size = tiling_data->hidden_size; + this->top_k = tiling_data->top_k; + this->num_out_tokens = tiling_data->num_out_tokens; + // hidden_tiling + this->hidden_splited_length = tiling_data->hidden_splited_length; + this->hidden_splited_num = tiling_data->hidden_splited_num; + this->hidden_splited_remain = tiling_data->hidden_splited_remain; + // token_tiling + this->tokens_core_length = tiling_data->tokens_core_length; + this->tokens_core_remain = tiling_data->tokens_core_remain; + this->tokens_splited_length = tiling_data->tokens_splited_length; + this->tokens_splited_num = tiling_data->tokens_splited_num; + this->tokens_splited_remain = tiling_data->tokens_splited_remain; + + // Handle the tail block for token_by_core + if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) { + this->tokens_core_length += 1; + this->tokens_splited_remain += 1; + } + + int64_t hidden_splited_length_align512 = (this->hidden_splited_length + ALIGN_512 - 1) & ~(ALIGN_512 - 1); + + int64_t block_length = this->tokens_core_length * this->top_k; + int64_t block_splited_length = this->tokens_splited_length * this->top_k; + + int64_t block_offset; + if (this->tokens_core_remain > 0) { + if (blockIdx < this->tokens_core_remain) { + block_offset = block_length * blockIdx; + } else { + block_offset = (block_length + this->top_k) * this->tokens_core_remain + + block_length * (blockIdx - this->tokens_core_remain); + } + } else { + block_offset = block_length * blockIdx; + } + + this->tokensGM.SetGlobalBuffer((__gm__ T1 *)permuted_tokens); + this->indicesGM.SetGlobalBuffer((__gm__ T2 *)sorted_indices + block_offset, block_length); + + + int64_t out_block_offset; + if (this->tokens_core_remain > 0) { + if (blockIdx < this->tokens_core_remain) { + out_block_offset = this->tokens_core_length * blockIdx * hidden_size; + } else { + out_block_offset = (this->tokens_core_length + 1) * this->tokens_core_remain + + this->tokens_core_length * (blockIdx - this->tokens_core_remain); + out_block_offset *= this->hidden_size; + } + } else { + out_block_offset = this->tokens_core_length * blockIdx * hidden_size; + } + + this->outGM.SetGlobalBuffer((__gm__ T1 *)unpermuted_tokens + out_block_offset, + this->tokens_core_length * this->hidden_size); + + this->pipe.InitBuffer(tokens_inque, tiling_data->buffer_num, hidden_splited_length_align512 * sizeof(T1)); + this->pipe.InitBuffer(indices_inque, 1, block_splited_length * (sizeof(T2))); + this->pipe.InitBuffer(outque, 1, hidden_splited_length_align512 * sizeof(T1)); + + if constexpr (!IsSameType::value) { + this->pipe.InitBuffer(temp_buffer0, hidden_splited_length_align512 * sizeof(float) + 256); + this->pipe.InitBuffer(temp_buffer1, hidden_splited_length_align512 * sizeof(float)); + this->token_tensor0 = this->temp_buffer0.template Get(); + this->token_tensor1 = this->temp_buffer1.template Get(); + } + + if constexpr (PROBS) { + this->probsGM.SetGlobalBuffer((__gm__ T3 *)probs + block_offset, block_length); + this->pipe.InitBuffer(probs_inque, 1, block_splited_length * (sizeof(T3))); + if constexpr (!IsSameType::value) { + this->pipe.InitBuffer(temp_buffer2, block_splited_length * sizeof(float)); + this->probs_tensor = this->temp_buffer2.template Get(); + } + } +}; + +template +__aicore__ inline void KernelMoeTokenUnpermute::Process() +{ + + if (blockIdx >= blockNum) { + return; + } + for (int64_t i = 0; i < this->tokens_splited_num; ++i) { + CalMultiOutToken(i * this->tokens_splited_length, this->tokens_splited_length); + } + // Handle the tail block when tokens_num is not evenly divisible by core count + if (this->tokens_splited_remain > 0) { + CalMultiOutToken(this->tokens_splited_num * this->tokens_splited_length, this->tokens_splited_remain); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalMultiOutToken(const int64_t out_offset, + const int64_t out_tokens_number) +{ + this->indicesLocal = this->indices_inque.template AllocTensor(); + int64_t in_offset = out_offset * this->top_k; + this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T2); + DataCopyPad(this->indicesLocal, this->indicesGM[in_offset], this->copyParams, this->extParams2); + this->indices_inque.template EnQue(this->indicesLocal); + + if constexpr (PROBS) { + LocalTensor temp_probs_tensor = this->probs_inque.template AllocTensor(); + this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T3); + DataCopyPad(temp_probs_tensor, this->probsGM[in_offset], this->copyParams, this->extParams3); + this->probs_inque.template EnQue(temp_probs_tensor); + temp_probs_tensor = this->probs_inque.template DeQue(); + if constexpr (!IsSameType::value) { + Cast(this->probs_tensor, temp_probs_tensor, RoundMode::CAST_NONE, out_tokens_number * this->top_k); + this->probs_inque.FreeTensor(temp_probs_tensor); + PipeBarrier(); + } else { + this->probs_tensor = temp_probs_tensor; + } + } + this->indicesLocal = this->indices_inque.template DeQue(); + + + for (int64_t out_token_idx = 0; out_token_idx < out_tokens_number; ++out_token_idx) { + CalSingleOutToken(out_token_idx * this->top_k, out_offset + out_token_idx); + } + // Free Tensor + this->indices_inque.FreeTensor(this->indicesLocal); + if constexpr (PROBS && IsSameType::value) { + this->probs_inque.FreeTensor(this->probs_tensor); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalSingleOutToken(const int64_t start_token, + const int64_t out_token_idx) +{ + for (int64_t h_index = 0; h_index < this->hidden_splited_num; ++h_index) { + CalPartOutToken(start_token, h_index, this->hidden_splited_length, out_token_idx); + } + // Handle the tail block when a full hidden_size does not fit in one pass + if (this->hidden_splited_remain > 0) { + CalPartOutToken(start_token, this->hidden_splited_num, this->hidden_splited_remain, out_token_idx); + } +} + +template +__aicore__ inline void +KernelMoeTokenUnpermute::CalPartOutToken(const int64_t start_token, const int64_t h_index, + const int64_t h_length, const int64_t out_token_index) +{ + if constexpr (IsSameType::value) { + this->token_tensor0 = this->outque.template AllocTensor(); + } + int64_t end_token = start_token + this->top_k; + T2 cal_token_idx = this->indicesLocal.GetValue(start_token); + + // Handle the first token + if (cal_token_idx < this->num_out_tokens) { + float probsValue = 0; + if constexpr (PROBS) { + probsValue = this->probs_tensor.GetValue(start_token); + } + + CopyTokenIn(cal_token_idx, h_index, h_length); + PipeBarrier(); + CalFirstToken(probsValue, h_length); + } else { + PipeBarrier(); + Duplicate(this->token_tensor0, static_cast(0), h_length); + } + + // Handle the remaining tokens + for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) { + cal_token_idx = this->indicesLocal.GetValue(token_index); + if (cal_token_idx < this->num_out_tokens) { + float probsValue = 0; + if constexpr (PROBS) { + probsValue = this->probs_tensor.GetValue(token_index); + } + + CopyTokenIn(cal_token_idx, h_index, h_length); + PipeBarrier(); + CalToken(probsValue, h_length); + } + } + + // Write out the computed result + CopyOut(out_token_index, h_index, h_length); +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CopyTokenIn(const T2 in_token_index, + const int64_t h_index, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template AllocTensor(); + int64_t offset = in_token_index * this->hidden_size + h_index * this->hidden_splited_length; + + if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) { + DataCopy(tokensLocal, this->tokensGM[offset], h_length); + } else { + this->copyParams.blockLen = h_length * sizeof(T1); + DataCopyPad(tokensLocal, this->tokensGM[offset], this->copyParams, this->extParams1); + } + + this->tokens_inque.template EnQue(tokensLocal); +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalFirstToken(const float prob_value, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template DeQue(); + + if constexpr (!IsSameType::value) { + Cast(this->token_tensor0, tokensLocal, RoundMode::CAST_NONE, h_length); + } else { + uint64_t byteAlign32 = (h_length * sizeof(float) + BLOCK_SIZE - 1) & ~(BLOCK_SIZE - 1); + DataCopy(this->token_tensor0, tokensLocal, byteAlign32 / sizeof(float)); + } + + this->tokens_inque.FreeTensor(tokensLocal); + + if constexpr (PROBS) { + PipeBarrier(); + Muls(this->token_tensor0, this->token_tensor0, prob_value, h_length); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalToken(const float prob_value, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template DeQue(); + + if constexpr (!IsSameType::value) { + Cast(this->token_tensor1, tokensLocal, RoundMode::CAST_NONE, h_length); + this->tokens_inque.FreeTensor(tokensLocal); + if constexpr (PROBS) { + PipeBarrier(); + Muls(this->token_tensor1, this->token_tensor1, prob_value, h_length); + } + PipeBarrier(); + Add(this->token_tensor0, this->token_tensor0, this->token_tensor1, h_length); + } else { + if constexpr (PROBS) { + Muls(tokensLocal, tokensLocal, prob_value, h_length); + PipeBarrier(); + } + Add(this->token_tensor0, this->token_tensor0, tokensLocal, h_length); + this->tokens_inque.FreeTensor(tokensLocal); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CopyOut(const int64_t out_token_index, + const int64_t h_index, + const int64_t h_length) +{ + LocalTensor temp_out_tensors; + if constexpr (!IsSameType::value) { + temp_out_tensors = this->outque.template AllocTensor(); + PipeBarrier(); + Cast(temp_out_tensors, this->token_tensor0, RoundMode::CAST_RINT, h_length); + } else { + temp_out_tensors = this->token_tensor0; + } + + this->outque.template EnQue(temp_out_tensors); + temp_out_tensors = this->outque.template DeQue(); + + int64_t offset = out_token_index * this->hidden_size + h_index * this->hidden_splited_length; + if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) { + DataCopy(this->outGM[offset], temp_out_tensors, h_length); + } else { + this->copyParams.blockLen = h_length * sizeof(T1); + DataCopyPad(this->outGM[offset], temp_out_tensors, this->copyParams); + } + + this->outque.FreeTensor(temp_out_tensors); +} +#endif // MOE_TOKEN_UNPERMUTE diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute_tiling.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute_tiling.h new file mode 100644 index 00000000..df47f6db --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/unpermute/moe_token_unpermute_tiling.h @@ -0,0 +1,38 @@ +#ifndef MOE_TOKEN_UNPERMUTE_TILING +#define MOE_TOKEN_UNPERMUTE_TILING + +struct MoeTokenUnpermuteTilingData { + int64_t hidden_size; + int64_t top_k; + int64_t num_out_tokens; + int64_t hidden_splited_length; + int64_t hidden_splited_num; + int64_t hidden_splited_remain; + int64_t tokens_core_length; + int64_t tokens_core_remain; + int64_t tokens_splited_length; + int64_t tokens_splited_num; + int64_t tokens_splited_remain; + int64_t buffer_num; +}; + +__forceinline__ [host, aicore] void +MoeTokenUnpermuteTiling(int32_t m, int32_t n, int32_t topK, MoeTokenUnpermuteTilingData &tilingData, uint32_t coreNum) +{ + #define I64(x) static_cast(x) + tilingData.hidden_size = I64(n); + tilingData.top_k = I64(topK); + tilingData.num_out_tokens = I64(m); + tilingData.hidden_splited_length = tilingData.hidden_size; + tilingData.hidden_splited_num = 1; + tilingData.hidden_splited_remain = 0; + uint32_t outTokens = m / topK; + tilingData.tokens_core_length = I64(outTokens / coreNum); + tilingData.tokens_core_remain = I64(outTokens % coreNum); + tilingData.tokens_splited_length = I64(min(tilingData.tokens_core_length, 600)); + tilingData.tokens_splited_num = I64(tilingData.tokens_core_length / tilingData.tokens_splited_length); + tilingData.tokens_splited_remain = I64(tilingData.tokens_core_length % tilingData.tokens_splited_length); + tilingData.buffer_num = 4; +} + +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_row.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_row.hpp new file mode 100644 index 00000000..4d949ead --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_row.hpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/epilogue/block/block_epilogue.hpp" + +namespace Catlass::Epilogue::Block { + +// float scale, dequant per expert +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequant, + CType_, + Gemm::GemmType, + DType_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + (std::is_same_v || std::is_same_v) && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong" + ); + static_assert( + std::is_same_v && + std::is_same_v && std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong" + ); + + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + struct Params { + __gm__ int32_t *ptrTokenPerExpert{nullptr}; + int32_t EP; + int32_t expertPerRank; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 4096; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + constexpr int32_t blockN = 12000; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + ubCFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(float); + } + } + CATLASS_DEVICE + void Finalize() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + CATLASS_DEVICE + ~BlockEpilogue() + { + + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + MatrixCoord const &shapeC, + AscendC::GlobalTensor const &gmPerTokenScale, + AscendC::GlobalTensor const &gmD + ) + { + uint32_t blockM = shapeC.row(); + uint32_t blockN = shapeC.column(); + + uint32_t tileLoops = blockM; + + for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx ++) { + auto gmTileC = gmC[loopIdx * blockN]; + auto &ubC = ubCList[ubListId]; + auto &ubCFp32 = ubCFp32List[ubListId]; + auto &ubMul = ubMulList[ubListId]; + auto &ubD = ubDList[ubListId]; + auto gmTileD = gmD[loopIdx * blockN]; + LayoutC layoutUbC{1, blockN}; + + // Move C from GM workspace to UB + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + // Cast C to FP32 in UB + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + // Get per-token scale from row loopIdx of gmPerTokenScale + ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + // Multiply FP32 C by the per-token scale + AscendC::PipeBarrier(); + AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN); + AscendC::PipeBarrier(); + + // Cast the muls result back to fp16/bf16 + LayoutD layoutUbD{1, blockN}; + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + + AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32List[UB_STAGES]; + AscendC::LocalTensor ubMulList[UB_STAGES]; + + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp new file mode 100644 index 00000000..0862458f --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp @@ -0,0 +1,402 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +namespace Catlass::Epilogue::Block { + +// float scale, dequant per expert +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileElemWiseMuls_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequantSwigluQuant, + CType_, + Gemm::GemmType, + DType_, + TileElemWiseMuls_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + (std::is_same_v || std::is_same_v) && + (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong" + ); + static_assert( + std::is_same_v && + std::is_same_v && std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong" + ); + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm>; + + struct Params { + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_ + ) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), layoutD(layoutD_) {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + constexpr uint32_t blockN = 4096; + constexpr uint32_t ChunkTileLen = blockN / 2; + constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2; + + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementD); + ubCFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(float); + ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ChunkTileLen * sizeof(float); + ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ChunkTileLen * sizeof(float); + ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += HalfChunkTileLen * sizeof(float); + ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast(); + ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast(); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + + ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + } + CATLASS_DEVICE + void Finalize() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + CATLASS_DEVICE + ~BlockEpilogue() + { + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + // 每个tile就是1*7168,每个block是一个expert的所有token=[group[i], 7168] + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + MatrixCoord const &shapeC, + AscendC::GlobalTensor const &gmPerTokenScale1, + AscendC::GlobalTensor const &gmD, + AscendC::GlobalTensor const &gmPerTokenScale2, + + uint32_t epilogueCoreNum = 40, + Callback &&callback = Callback{} + ) + { + callback(); + uint32_t blockM = shapeC.row(); + uint32_t blockN = shapeC.column(); + + uint32_t tileLoops = blockM; + uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + uint32_t subblockNum = get_block_num() * 2; + uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum; + + if (subblockIdx < moveDataCoreNum) { + return; + } + uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum; + + uint32_t perCoreData = blockM / epilogueCoreNum; + uint32_t remainderData = blockM % epilogueCoreNum; + + uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData; + uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData); + + uint32_t alignedPerCoreData = RoundUp(perCoreData + 1); + + uint32_t ChunkTileLen = blockN / 2; + uint32_t HalfChunkTileLen = ChunkTileLen / 2; + + + for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) { + + auto gmTileC = gmC[loopIdx * blockN]; + + auto &ubC = ubCList[ubListId]; + auto &ubD = ubDList[ubListId]; + + auto &ubCFp32 = ubCFp32List[ubListId]; + auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId]; + auto &ubAbs = ubCFp32ChunkNAbsList[ubListId]; + // auto &ubMax = ubCFp32ChunkNMaxList[ubListId]; + auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId]; + auto &ubOutputTmp = ubAbs; + auto &sharedUbTmpBuffer = ubReduceMax; + auto &ubQuantS32 = ubQuantS32List[ubListId]; + auto &ubQuantF16 = ubQuantF16List[ubListId]; + + auto gmTileD = gmD[loopIdx * ChunkTileLen]; + LayoutC layoutUbC{1, blockN}; + + // 把C从GM workspace搬到UB + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + // 在UB上做把C cast成FP32 + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + // 获取pertoken scale值,gmPerTokenScale的第loopIdx行 + ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + // pertoken scale值与FP32的C做Muls乘法 + AscendC::PipeBarrier(); + AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN); + AscendC::PipeBarrier(); + + //swiglue计算过程 + AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + //TODO除的时候是否会对之后的数据有影响; + AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen); + + //quant过程,两种方式区别; + AscendC::PipeBarrier(); + AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::ReduceMax(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false); + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + //TODO两种计算方法的效率比较 + ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0); + AscendC::SetFlag(0); + + auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx; + ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f); + + AscendC::WaitFlag(0); + AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::SetDeqScale(static_cast(1.0)); + AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + // AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + LayoutD layoutUbD{1, ChunkTileLen}; + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + + if(tasksForIdx > 0){ + LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx}; + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2); + } + + + } + + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + MatrixCoord const &shapeC, + AscendC::GlobalTensor const &gmD, + uint32_t epilogueCoreNum = 40, + Callback &&callback = Callback{} + ) + { + callback(); + uint32_t blockM = shapeC.row(); + uint32_t blockN = shapeC.column(); + + uint32_t tileLoops = blockM; + uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num(); + //uint32_t subblockIdx = get_block_idx() * 2 + get_subblockid(); + + uint32_t subblockNum = get_block_num() * 2; + uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum; + + if (subblockIdx < moveDataCoreNum) { + return; + } + uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum; + + + uint32_t perCoreData = blockM / epilogueCoreNum; + uint32_t remainderData = blockM % epilogueCoreNum; + + uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData; + uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData); + + uint32_t alignedPerCoreData = RoundUp(perCoreData + 1); + + uint32_t ChunkTileLen = blockN / 2; + uint32_t HalfChunkTileLen = ChunkTileLen / 2; + + + for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) { + + auto gmTileC = gmC[loopIdx * blockN]; + + auto &ubC = ubCList[ubListId]; + auto &ubD = ubDList[ubListId]; + + auto &ubCFp32 = ubCFp32List[ubListId]; + auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId]; + + auto gmTileD = gmD[loopIdx * ChunkTileLen]; + LayoutC layoutUbC{1, blockN}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + AscendC::Cast(ubD, ubCFp32ChunkN, AscendC::RoundMode::CAST_ROUND, ChunkTileLen); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + LayoutD layoutUbD{1, ChunkTileLen}; + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + // copyUbToGmD(gmTileD, ubCFp32ChunkN, layoutUbD, layoutUbD); + copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + + } + + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32List[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNList[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNAbsList[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNMaxList[UB_STAGES]; + AscendC::LocalTensor ubQuantS32List[UB_STAGES]; + AscendC::LocalTensor ubQuantF16List[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleOutput; + + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; + CopyUbToGmDequantScale copyUbToGmDequantScale; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp new file mode 100644 index 00000000..eaab8104 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_v2.hpp @@ -0,0 +1,330 @@ +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "hccl_shmem.hpp" +#include "layout3d.hpp" + +namespace Catlass::Epilogue::Block { +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequantV2, + CType_, + Gemm::GemmType, + DType_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + //using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub>; + using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub>; + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + struct Params { + __gm__ int32_t *ptrTokenPerExpert{nullptr}; + int32_t EP; + int32_t expertPerRank; + int32_t n2; + LayoutC layoutC; + int32_t n0; + int32_t rank; + HcclShmem shmem; + int32_t offsetD; + + CATLASS_DEVICE + Params() {}; + CATLASS_DEVICE + Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_, + LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) : + ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), + expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_), + shmem(shmem_), offsetD(offsetD_) + {} + }; + + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + + + + //ub:192KB + n0 = params.n0; + size_t ubOffset = 0; + for(int32_t i = 0; i < 2; i++) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementD); + ubFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(float); + scaleUbList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += (max_len / n0) * sizeof(float); + source_scale_offset[i] = -1; + } + tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert)); + tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank); + is_ping = true; + } + + CATLASS_DEVICE + void Finalize() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + } + CATLASS_DEVICE + ~BlockEpilogue() + { + + } + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + AscendC::GlobalTensor const &gmPerTokenScale, + GemmCoord& blockCoord, + GemmCoord& actualBlockShape, + int32_t groupIdx, + int32_t preSrcExpertSum, + AscendC::GlobalTensor preSumBeforeRank, + uint32_t *mPreSumBeforeRank + ){ + is_ping = !is_ping; + auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1; + auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3; + + auto &ubC = ubCList[is_ping]; + auto &ubD = ubDList[is_ping]; + int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n(); + auto gmTileC = gmC[gmCOffset]; + auto &ubCFp32 = ubFp32List[is_ping]; + auto &scaleUb = scaleUbList[is_ping]; + // auto &ubOutFp32 = ubOutFp32List[is_ping]; + + LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2}; + LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0}; + + + AscendC::WaitFlag(event_id); //for debug + copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM); + AscendC::SetFlag(event_id); //for debug + + AscendC::WaitFlag(event_id); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4}); + AscendC::SetFlag(event_id); + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); + + int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m(); + layout::VectorLayout scaleLauout{actualBlockShape.m()}; + if (source_scale_offset[event_id] != gmScaleOffset) { + source_scale_offset[event_id] = gmScaleOffset; + copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout); + } + + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + + + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); // 注意必须是MTE2_S,不能是MTE2_V,否则会读到0,造成乱码 + AscendC::PipeBarrier(); + for (int32_t row = 0; row < actualBlockShape.m(); ++row) { + float scale = scaleUb(row); + Muls(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8}); + } + AscendC::PipeBarrier(); + AscendC::WaitFlag(event_id); + AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8}); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id); + + int32_t lenTile = actualBlockShape.m(); + int32_t stTile = blockCoord.m(); + int32_t edTile = stTile + lenTile; + int32_t preSumRankInExpert = 0; + int32_t tileOffset = 0; + + AscendC::WaitFlag(event_id); //for debug + for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) { + int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16); + int32_t stRankInExpert = preSumRankInExpert; + int32_t edRankInExpert = stRankInExpert + lenRankInExpert; + preSumRankInExpert += lenRankInExpert; + if (stRankInExpert >= edTile) { + break; + } + else if (edRankInExpert <= stTile) { + continue; + } + int32_t stData = max(stRankInExpert, stTile); + int32_t edData = min(edRankInExpert, edTile); + uint32_t lenData = edData - stData; + if (lenData <= 0){ + continue; + } + + uint32_t dstOffsetInExpert = 0; + if (stTile > stRankInExpert) { + dstOffsetInExpert = stTile - stRankInExpert; + } + AscendC::GlobalTensor gmRemotePeer; + __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx); + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr)); + MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()}; + int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset); + auto gmTileD = gmRemotePeer[gmDstOffset]; + LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2}; + LayoutC layoutUB2{lenData, actualBlockShape.n(), n0}; + copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2); + tileOffset += lenData; + } + AscendC::SetFlag(event_id); + + } + + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + GemmCoord& blockCoord, + GemmCoord& actualBlockShape, + int32_t groupIdx, + int32_t preSrcExpertSum, + AscendC::GlobalTensor preSumBeforeRank, + uint32_t *mPreSumBeforeRank + ){ + is_ping = !is_ping; + auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1; + + auto &ubC = ubCList[is_ping]; + auto &ubD = ubDList[is_ping]; + int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n(); + auto gmTileC = gmC[gmCOffset]; + + LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2}; + LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0}; + + + AscendC::SetFlag(event_id); //for debug + copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM); + AscendC::SetFlag(event_id); //for debug + + int32_t lenTile = actualBlockShape.m(); + int32_t stTile = blockCoord.m(); + int32_t edTile = stTile + lenTile; + int32_t preSumRankInExpert = 0; + int32_t tileOffset = 0; + + AscendC::WaitFlag(event_id); //for debug + for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) { + int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16); + int32_t stRankInExpert = preSumRankInExpert; + int32_t edRankInExpert = stRankInExpert + lenRankInExpert; + preSumRankInExpert += lenRankInExpert; + if (stRankInExpert >= edTile) { + break; + } + else if (edRankInExpert <= stTile) { + continue; + } + int32_t stData = max(stRankInExpert, stTile); + int32_t edData = min(edRankInExpert, edTile); + uint32_t lenData = edData - stData; + if (lenData <= 0){ + continue; + } + + uint32_t dstOffsetInExpert = 0; + if (stTile > stRankInExpert) { + dstOffsetInExpert = stTile - stRankInExpert; + } + AscendC::GlobalTensor gmRemotePeer; + __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx); + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr)); + MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()}; + int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset); + auto gmTileD = gmRemotePeer[gmDstOffset]; + LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2}; + LayoutC layoutUB2{lenData, actualBlockShape.n(), n0}; + copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2); + tileOffset += lenData; + } + AscendC::WaitFlag(event_id); + + } + +private: + + Params params; + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + AscendC::LocalTensor ubFp32List[UB_STAGES]; + AscendC::LocalTensor scaleUbList[UB_STAGES]; + int32_t source_scale_offset[UB_STAGES]; + + int32_t max_len = 8 * 32 / 4 * 128; + int32_t n0; + bool is_ping = false; + + + int32_t repeat = 128; + + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; + + CopyScaleGmToUb copyScaleGmToUb; + AscendC::GlobalTensor tokenPerExpert; + Layout3D tokenPerExpertLayout; +}; +} +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp new file mode 100644 index 00000000..c15e11b2 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -0,0 +1,502 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "dispatch_policy_custom.hpp" + + +namespace Catlass::Gemm::Block { + +template +__aicore__ inline void SyncFlagFunc(int32_t eventID) +{ + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +template < + uint32_t PRELOAD_STAGES_, + uint32_t L1_STAGES_, + uint32_t L0A_STAGES_, + uint32_t L0B_STAGES_, + uint32_t L0C_STAGES_, + bool ENABLE_UNIT_FLAG_, + bool ENABLE_SHUFFLE_K_, + class L1TileShape_, + class L0TileShape_, + class AType_, + class BType_, + class CType_, + class BiasType_, + class TileCopy_, + class TileMmad_ +> +struct BlockMmad < + MmadAtlasA2PreloadAsyncFixpipe< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + >, + L1TileShape_, + L0TileShape_, + AType_, + BType_, + CType_, + BiasType_, + TileCopy_, + TileMmad_ +> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + >; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyGmToL1S = Gemm::Tile::CopyGmToL1>; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using CopyL0CToGm = typename std::conditional< + std::is_same_v, + Gemm::Tile::CopyL0CToGm, + typename TileCopy_::CopyL0CToGm + >::type; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert( + (std::is_same_v + ? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE + : (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE), + "L1TileShape exceeding the L1 space for the given data type" + ); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout( + L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout( + L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + syncGroupIdx = 0; + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + if constexpr (std::is_same_v) { + AscendC::WaitFlag(0); + } + } + + CATLASS_DEVICE + void operator()( + AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmBlockS, layout::VectorLayout const &layoutScale, + GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0 + ) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? + (startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = (kTileIdx < kTileCount - 1) ? + L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ? + (l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + l1TileMmadParams.flag = flag; + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.gmBlockS = gmBlockS; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.layoutScale = layoutScale; + l1TileMmadParams.syncLoopIdx = syncLoopIdx; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + + CATLASS_DEVICE + void Finalize(int32_t target, int32_t flag = 0) + { + for(;syncGroupIdx <= target; syncGroupIdx++) { + int32_t flagId = syncGroupIdx / 8 + flag; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } + } +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + AscendC::GlobalTensor gmBlockS; + LayoutC layoutCInGm; + layout::VectorLayout layoutScale; + int32_t syncLoopIdx; + int32_t flag; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + if constexpr (std::is_same_v) { + uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; + l1STensor = resource.l1Buf.template GetBufferByByte(l1SOffset); + AscendC::SetFlag(0); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ? + L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ? + L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ? + L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && + (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + if constexpr (std::is_same_v) { + auto layoutScale = params.layoutScale; + auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1))); + AscendC::WaitFlag(0); + copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + } + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + if constexpr (std::is_same_v) { + copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + } + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + if constexpr (std::is_same_v) { + copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; + if constexpr (std::is_same_v) { + AscendC::SetFlag(0); + } + Finalize(params.syncLoopIdx, params.flag); + } + } + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + AscendC::LocalTensor l1STensor; + int32_t syncGroupIdx; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyGmToL1S copyGmToL1S; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp new file mode 100644 index 00000000..12262c68 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/const_args.hpp @@ -0,0 +1,9 @@ + +#ifndef CONST_ARGS_HPP +#define CONST_ARGS_HPP +constexpr static uint64_t MB_SIZE = 1024 * 1024UL; +constexpr static int32_t NUMS_PER_FLAG = 16; +constexpr static int32_t CACHE_LINE = 512; +constexpr static int32_t RESET_VAL = 0xffff; +constexpr uint32_t MAX_EXPERTS_PER_RANK = 32; +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_gm_to_l1_custom.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_gm_to_l1_custom.hpp new file mode 100644 index 00000000..84789073 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_gm_to_l1_custom.hpp @@ -0,0 +1,40 @@ +#ifndef COPY_GM_TO_L1_CUSTOM_HPP +#define COPY_GM_TO_L1_CUSTOM_HPP + +namespace Catlass::Gemm::Tile { + /// Partial specialization for nZ in and nZ out. + template < + class ArchTag, + class Element + > + struct CopyGmToL1> { + using LayoutDst = layout::VectorLayout; + using LayoutSrc = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); // int64, 32/8=4 + + // Mehtods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()( + AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = 1; + uint32_t blockLen = CeilDiv(layoutSrc.shape(0)); + + AscendC::DataCopyParams repeatParams; + + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } + }; +} +#endif // COPY_GM_TO_L1_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_l0c_to_gm_custom.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_l0c_to_gm_custom.hpp new file mode 100644 index 00000000..ba477984 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/copy_l0c_to_gm_custom.hpp @@ -0,0 +1,47 @@ +#ifndef COPY_L0C_TO_GM_CUSTOM_HPP +#define COPY_L0C_TO_GM_CUSTOM_HPP + +namespace Catlass::Gemm::Tile { + template < + class ElementAccumulator_, + class ElementDst_, + bool ReluEnable_ + > + struct CopyL0CToGm, + ScaleGranularity::PER_CHANNEL, + ReluEnable_> + { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Catlass::layout::zN; + using LayoutDst = Catlass::layout::RowMajor; + static constexpr auto quantPre = CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, AscendC::LocalTensor cbufWorkspace, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(1); + intriParams.mSize = dstLayout.shape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); + intriParams.dstStride = dstLayout.stride(0); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, cbufWorkspace, intriParams); + } + }; +} +#endif // COPY_L0C_TO_GM_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/dispatch_policy_custom.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/dispatch_policy_custom.hpp new file mode 100644 index 00000000..2ae65fa3 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/dispatch_policy_custom.hpp @@ -0,0 +1,53 @@ +#ifndef DISPATH_POLICY_CUSTOM_HPP +#define DISPATH_POLICY_CUSTOM_HPP + +namespace Catlass::Gemm { + template + struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; + }; + + template + struct MmadAtlasA2PreloadAsyncFixpipe : + public MmadAtlasA2PreloadAsync< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + > { + }; +} + +namespace Catlass::Epilogue { + + template + struct EpilogueAtlasA2UnQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; + + template + struct EpilogueAtlasA2PerTokenDequantQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; + + template + struct EpilogueAtlasA2PerTokenDequantSwigluQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; + + template + struct EpilogueAtlasA2PerTokenDequantV2 { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; +} +#endif // DISPATH_POLICY_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/get_tensor_addr.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/get_tensor_addr.hpp new file mode 100644 index 00000000..67b32c25 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/get_tensor_addr.hpp @@ -0,0 +1,16 @@ +#ifndef GET_TENSOR_ADDR_HPP +#define GET_TENSOR_ADDR_HPP +#include "kernel_operator.h" + +#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ + +template +FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) { + __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr); + uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address. + // Moving 3 bits to the right means dividing by sizeof(uint64 t). + __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} + +#endif // GET_TENSOR_ADDR_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp new file mode 100644 index 00000000..ec88e8fd --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp @@ -0,0 +1,195 @@ +#ifndef SYNC_UTIL_HPP +#define SYNC_UTIL_HPP + + +#include "kernel_operator.h" +#include "const_args.hpp" + +#include "moe_distribute_base.h" + +#ifndef HCCL_COMM +#include "shmem_api.h" +using namespace AscendC::HcclContextDef; +#endif + +#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ +constexpr int32_t MAX_RANK_SIZE = 32; +constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE; + +FORCE_INLINE_AICORE void AicSyncAll() { + AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8); + AscendC::CrossCoreWaitFlag<0x0>(8); +} + +template +FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) { + *((__gm__ T *)addr) = val; +} + +template +FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) { + return *((__gm__ T *)cache); +} + +FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { + using namespace AscendC; + GlobalTensor global; + global.SetGlobalBuffer(addr); + + // Important: add hint to avoid dcci being optimized by compiler + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); +} + +FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) { + do { + gm_dcci((__gm__ uint8_t *)sig_addr); + if (*sig_addr == cmp_val) { + return *sig_addr; + } + if (*sig_addr == cmp_val + 1) { + return *sig_addr; + } + } while (true); + return -1; +} + +FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { + do { + AscendC::LocalTensor ub; + ub.address_.logicPos = static_cast(TPosition::VECIN); + ub.address_.bufferAddr = 0; + AscendC::GlobalTensor sig; + sig.SetGlobalBuffer(sig_addr); + AscendC::DataCopy(ub, sig, 8); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if (ub(0) != cmp_val) { + return; + } + } while (true); + return; +} + + +class HcclShmem { +public: + #ifdef HCCL_COMM + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + Hccl hccl_; + GM_ADDR m_ptrArray[MAX_RANK_SIZE]; + FORCE_INLINE_AICORE + HcclShmem(){ + auto contextGM0 = AscendC::GetHcclContext(); + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; + + m_rank = WinContext_->localUsrRankId; + m_rankSize = WinContext_->rankSize; + m_segmentSize = WinContext_->winSize; + for (int i = 0; i < m_rankSize; i++) { + m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn); + } + } + #else + FORCE_INLINE_AICORE + HcclShmem(){ + m_segmentSize = SHMEM_MEM; + } + FORCE_INLINE_AICORE + void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) { + symmetricPtr = symmetricPtr_; + m_rank = rank; + m_rankSize = rankSize; + } + #endif + + FORCE_INLINE_AICORE + GM_ADDR operator() () const { + #ifdef HCCL_COMM + return m_ptrArray[m_rank]; + #else + return reinterpret_cast(shmem_ptr(symmetricPtr, m_rank)); + #endif + } + + FORCE_INLINE_AICORE + GM_ADDR operator() (int32_t index) const { + #ifdef HCCL_COMM + return m_ptrArray[index]; + #else + return reinterpret_cast(shmem_ptr(symmetricPtr, index)); + #endif + } + + FORCE_INLINE_AICORE + GM_ADDR operator () (int64_t offset, int32_t rankId) const { + #ifdef HCCL_COMM + if (offset < 0 || offset >= m_segmentSize) { + return nullptr; + } + if (rankId < 0 || rankId >= m_rankSize) { + return nullptr; + } + return m_ptrArray[rankId] + offset; + #else + return reinterpret_cast(shmem_ptr((symmetricPtr + offset), rankId)); + #endif + } + + + + FORCE_INLINE_AICORE + size_t SegmentSize() const { + return m_segmentSize; + } + + FORCE_INLINE_AICORE + int32_t RankSize() const { + return m_rankSize; + } + + + FORCE_INLINE_AICORE + ~HcclShmem() { + } + + + FORCE_INLINE_AICORE + void CrossRankSync() { + uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); + __gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset; + __gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048; + int count = gm_load(sync_base) + 1; + int vec_id = AscendC::GetBlockIdx(); + int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation(); + for(int i = vec_id; i < m_rankSize; i += vec_size) { + __gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16; + gm_store(sync_remote, count); + gm_dcci((__gm__ uint8_t*)sync_remote); + auto sync_check = sync_counter + i * 16; + gm_signal_wait_until_eq_for_barrier(sync_check, count); + } + + AscendC::SyncAll(); + gm_store(sync_base, count); + } + + FORCE_INLINE_AICORE + __gm__ int32_t* SyncBaseAddr() { + uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); + return (__gm__ int32_t*)(*this)() + flag_offset + 2048; + } + +private: + GM_ADDR symmetricPtr; + int32_t m_rank; + int32_t m_rankSize; + size_t m_segmentSize; +}; + + + + +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp new file mode 100644 index 00000000..7cc3a9c1 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp @@ -0,0 +1,20 @@ +#ifndef LAYOUT_3D_HPP +#define LAYOUT_3D_HPP +#include "kernel_operator.h" +#include "catlass/catlass.hpp" +class Layout3D { + int64_t strides[2]; + public: + CATLASS_DEVICE + Layout3D() {} + CATLASS_DEVICE + Layout3D(int64_t stride0, int64_t stride1) { + strides[0] = stride0; + strides[1] = stride1; + } + CATLASS_DEVICE + int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) { + return dim0 * strides[0] + dim1 * strides[1] + dim2; + } +}; +#endif // LAYOUT_3D_HPP diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/select_helper.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/select_helper.hpp new file mode 100644 index 00000000..574ab335 --- /dev/null +++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/select_helper.hpp @@ -0,0 +1,25 @@ +#ifndef SELECT_HELPER_HPP +#define SELECT_HELPER_HPP + +#include "catlass/layout/layout.hpp" +using namespace AscendC; +using namespace Catlass; + +template +struct LayoutBInitializer { + CATLASS_DEVICE + static Layout create(uint32_t k, uint32_t n) { + return Layout{k, n}; + } +}; + +template +struct LayoutBInitializer> +> { + CATLASS_DEVICE + static Layout create(uint32_t k, uint32_t n) { + return Layout::template MakeLayout(k, n); + } +}; +#endif // SELECT_HELPER_HPP \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index fe06fbe5..2d46826a 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -738,7 +738,9 @@ at::Tensor& dispatch_ffn_combine( at::Tensor& out ) { char *group_ep_ptr = const_cast(group.data()); - EXEC_NPU_CMD(aclnnDispatchFFNCombine, + bool is_int8 = weight1[0].dtype() == at::kChar; + if (is_int8) { + EXEC_NPU_CMD(aclnnDispatchFFNCombine, x, weight1, weight2, @@ -749,6 +751,19 @@ at::Tensor& dispatch_ffn_combine( group_ep_ptr, max_output_size, out); + } else { + EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16, + x, + weight1, + weight2, + expert_idx, + scale1, + scale2, + probs, + group_ep_ptr, + max_output_size, + out); + } return out; } diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py new file mode 100644 index 00000000..5b50c07a --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py @@ -0,0 +1,234 @@ +import random + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch_npu +from torch.distributed.distributed_c10d import _get_default_group + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + + +class TestDisptachFFNCombine: + + def __init__(self, rank, world_size, port): + self.rank = rank + self.world_size = world_size + self.master_ip = "127.0.0.1" + self.port = port + + def get_hcomm(self, comm_group): + hcomm_info = None + if torch.__version__ > "2.0.1": + hcomm_info = comm_group._get_backend( + torch.device("npu")).get_hccl_comm_name(self.rank) + else: + hcomm_info = comm_group.get_hccl_comm_name(self.rank) + return hcomm_info + + def setup_ep_tp( + self, + rank, + tp_size, + ep_size, + backend_type, + ep_ranks_list=None, + tp_ranks_list=None, + ): + for i in range(tp_size): + if ep_ranks_list: + ep_ranks = ep_ranks_list[i] + else: + ep_ranks = [x + ep_size * i for x in range(ep_size)] + ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks) + if rank in ep_ranks: + ep_group_tmp = ep_group + for i in range(ep_size): + if tp_ranks_list: + tp_ranks = tp_ranks_list[i] + else: + tp_ranks = [x * ep_size + i for x in range(tp_size)] + tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks) + if rank in tp_ranks: + tp_group_tmp = tp_group + return ep_group_tmp, tp_group_tmp + + def generate_hcom(self): + torch_npu.npu.set_device(self.rank) + dist.init_process_group( + backend="hccl", + rank=self.rank, + world_size=self.world_size, + init_method=f"tcp://127.0.0.1:{self.port}", + ) + + ep_size = 0 + tp_size = self.world_size + hcomm_info_dist = { + "default_pg_info": None, + "ep_hcomm_info": None, + "group_ep": None, + "tp_hcomm_info": None, + "group_tp": None, + } + if ep_size and tp_size: + group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size, + "hccl", None, None) + hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep) + hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp) + hcomm_info_dist["group_ep"] = group_ep + hcomm_info_dist["group_tp"] = group_tp + else: + if dist.is_available(): + default_pg = _get_default_group() + hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg) + hcomm_info = hcomm_info_dist["default_pg_info"] + self.hcomm_info = hcomm_info + + def run_tensor_list(self) -> bool: + torch_npu.npu.set_device(self.rank) + m = 64 + k = 1024 + n = 1024 + topk = 8 + e = 8 + k2 = n // 2 + n2 = k + + torch_npu.npu.config.allow_internal_format = True + x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + weight1 = self.generate_random_tensor((e, k, n), + dtype=torch.bfloat16).npu() + weight1 = torch_npu.npu_format_cast(weight1, 29) + weight2 = self.generate_random_tensor((e, k2, n2), + dtype=torch.bfloat16).npu() + weight2 = torch_npu.npu_format_cast(weight2, 29) + + expert_idx = torch.randint(0, + self.world_size * e, (m, topk), + dtype=torch.int32).npu() + scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu() + scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu() + probs = torch.randn(size=(m, topk), dtype=torch.float32).npu() + + weight1_nz_npu = [] + weight2_nz_npu = [] + scale1_npu = [] + scale2_npu = [] + for i in range(e): + weight1_nz_npu.append( + torch_npu.npu_format_cast(weight1[i].npu(), 29)) + scale1_npu.append(scale1[i].npu()) + weight2_nz_npu.append( + torch_npu.npu_format_cast(weight2[i].npu(), 29)) + scale2_npu.append(scale2[i].npu()) + + out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + + torch.ops._C_ascend.dispatch_ffn_combine( + x=x, + weight1=weight1_nz_npu, + weight2=weight2_nz_npu, + expert_idx=expert_idx, + scale1=scale1_npu, + scale2=scale2_npu, + probs=probs, + group=self.hcomm_info, + max_output_size=512, + out=out, + ) + return True + + def run_normal(self) -> bool: + torch_npu.npu.set_device(self.rank) + m = 64 + k = 1024 + n = 1024 + topk = 8 + e = 8 + k2 = n // 2 + n2 = k + + torch_npu.npu.config.allow_internal_format = True + x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + weight1 = self.generate_random_tensor((e, k, n), + dtype=torch.bfloat16).npu() + weight1 = torch_npu.npu_format_cast(weight1, 29) + weight2 = self.generate_random_tensor((e, k2, n2), + dtype=torch.bfloat16).npu() + weight2 = torch_npu.npu_format_cast(weight2, 29) + + expert_idx = torch.randint(0, + self.world_size * e, (m, topk), + dtype=torch.int32).npu() + scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu() + scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu() + probs = torch.randn(size=(m, topk), dtype=torch.float32).npu() + + weight1_nz_npu = [] + weight2_nz_npu = [] + scale1_npu = [] + scale2_npu = [] + weight1_nz_npu.append(torch_npu.npu_format_cast(weight1.npu(), 29)) + scale1_npu.append(scale1.npu()) + weight2_nz_npu.append(torch_npu.npu_format_cast(weight2.npu(), 29)) + scale2_npu.append(scale2.npu()) + + out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + + torch.ops._C_ascend.dispatch_ffn_combine( + x=x, + weight1=weight1_nz_npu, + weight2=weight2_nz_npu, + expert_idx=expert_idx, + scale1=scale1_npu, + scale2=scale2_npu, + probs=probs, + group=self.hcomm_info, + max_output_size=512, + out=out, + ) + return True + + def generate_random_tensor(self, size, dtype): + if dtype in [torch.float16, torch.bfloat16, torch.float32]: + return torch.randn(size=size, dtype=dtype) + elif dtype is torch.int8: + return torch.randint(-16, 16, size=size, dtype=dtype) + elif dtype is torch.int32: + return torch.randint(-1024, 1024, size=size, dtype=dtype) + else: + raise ValueError(f"Invalid dtype: {dtype}") + + +def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue): + op = TestDisptachFFNCombine(rank, world_size, port) + op.generate_hcom() + out1 = op.run_tensor_list() + q.put(out1) + out2 = op.run_normal() + q.put(out2) + + +@torch.inference_mode() +def test_dispatch_ffn_combine_kernel(): + world_size = 2 + mp.set_start_method("fork", force=True) + + q = mp.SimpleQueue() + p_list = [] + port = 29501 + random.randint(0, 10000) + + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, port, q)) + p.start() + p_list.append(p) + + results = [q.get() for _ in range(world_size)] + + for p in p_list: + p.join() + + assert all(results)