add dispatch_gmm_combine kernel (#3532)
### What this PR does / why we need it? This PR introduces the Ascend implementation of the `dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime, together with follow‑up fixes to ensure the kernel builds and runs correctly in CI. - Add full host and device implementation of the `dispatch_ffn_combine` kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE routing helpers, and kernel utilities for quantized FFN dispatch. - Integrate the new kernel with the PyTorch binding (csrc/torch_binding.cpp, csrc/torch_binding_meta.cpp) and the Ascend runtime (vllm_ascend/ascend_forward_context.py, vllm_ascend/worker/model_runner_v1.py). - Extend fused MoE communication and token dispatch support in `vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new dispatch path. - Update quantization logic in vllm_ascend/quantization/w8a8_dynamic.py to support the new FFN dispatch flow. - Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake configuration, and include/namespace usage in the new kernel files. - Add an end‑to‑end nightly test `tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper utilities in `vllm_ascend/utils.py` to validate the new kernel. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0 --------- Signed-off-by: mojave2 <chenchen145@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
4
.gitmodules
vendored
Normal file
4
.gitmodules
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
[submodule "csrc/third_party/catlass"]
|
||||
path = csrc/third_party/catlass
|
||||
url = https://gitcode.com/cann/catlass.git
|
||||
branch = catlass-v1-stable
|
||||
@@ -3,6 +3,8 @@
|
||||
ROOT_DIR=$1
|
||||
SOC_VERSION=$2
|
||||
|
||||
git config --global --add safe.directory "$ROOT_DIR"
|
||||
|
||||
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
# ASCEND310P series
|
||||
# currently, no custom aclnn ops for ASCEND310 series
|
||||
@@ -11,11 +13,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
exit 0
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
# ASCEND910B (A2) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
|
||||
SOC_ARG="ascend910_93"
|
||||
else
|
||||
# others
|
||||
@@ -23,6 +25,30 @@ else
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git submodule init
|
||||
git submodule update
|
||||
|
||||
|
||||
# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h
|
||||
file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1)
|
||||
if [ -z "$file_path" ]; then
|
||||
echo "cannot find moe_distribute_base.h file in CANN env"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
|
||||
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
|
||||
TARGET_FILE="$TARGET_DIR/$(basename "$file_path")"
|
||||
|
||||
echo "*************************************"
|
||||
echo $file_path
|
||||
echo "$TARGET_DIR"
|
||||
cp "$file_path" "$TARGET_DIR"
|
||||
|
||||
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
|
||||
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
|
||||
|
||||
|
||||
# build custom ops
|
||||
cd csrc
|
||||
rm -rf build output
|
||||
|
||||
@@ -282,7 +282,7 @@ function(add_ops_src_copy)
|
||||
set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
|
||||
add_custom_command(OUTPUT ${_BUILD_FLAG}
|
||||
COMMAND mkdir -p ${SRC_COPY_DST}
|
||||
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
|
||||
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST}
|
||||
COMMAND touch ${_BUILD_FLAG}
|
||||
)
|
||||
|
||||
|
||||
66
csrc/dispatch_ffn_combine/op_host/CMakeLists.txt
Normal file
66
csrc/dispatch_ffn_combine/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
set(_DISPATCH_FFN_INC_OPTS)
|
||||
if (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include)
|
||||
list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include)
|
||||
elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include)
|
||||
list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include)
|
||||
elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include)
|
||||
list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include)
|
||||
endif()
|
||||
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
|
||||
list(APPEND _DISPATCH_FFN_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
|
||||
endif()
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME DispatchFFNCombine
|
||||
OPTIONS --cce-auto-sync=on
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-DHCCL_COMM
|
||||
${_DISPATCH_FFN_INC_OPTS}
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
dispatch_ffn_combine_def.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_dispatch_ffn_combine.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_dispatch_ffn_combine.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_dispatch_ffn_combine.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
dispatch_ffn_combine_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
dispatch_ffn_combine_proto.cpp
|
||||
)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_ffn_combine.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "aclnn_dispatch_ffn_combine.h"
|
||||
#include <algorithm>
|
||||
// #include "aclnn_kernels/common/op_error_check.h"
|
||||
// #include "opdev/op_log.h"
|
||||
// #include "opdev/common_types.h"
|
||||
// #include "opdev/platform.h"
|
||||
// #include "ophost/matmul_util.h"
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fcntl.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/file.h>
|
||||
#include <climits>
|
||||
#include "../op_host/error_log.h"
|
||||
// using namespace op;
|
||||
|
||||
// using namespace op;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static constexpr size_t TWO_DIMS = 2;
|
||||
static constexpr int64_t KVALUE_MIN = 256;
|
||||
static constexpr int64_t KVALUE_MAX = 65535;
|
||||
static constexpr size_t HCCL_GROUP_NAME_MAX = 128U;
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
|
||||
extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
|
||||
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
|
||||
const aclTensor* probs,
|
||||
const char* group, int64_t maxOutputSize,
|
||||
bool transB, bool weightNz,
|
||||
const aclTensor* out,
|
||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||
extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize,
|
||||
aclOpExecutor *executor, aclrtStream stream);
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
|
||||
|
||||
aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
|
||||
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
|
||||
const aclTensor* probs,
|
||||
const char* group, int64_t maxOutputSize,
|
||||
const aclTensor* out,
|
||||
uint64_t* workspaceSize, aclOpExecutor** executor)
|
||||
{
|
||||
bool transB = false;
|
||||
bool weightNz = true;
|
||||
|
||||
aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
|
||||
maxOutputSize, transB, weightNz,
|
||||
out, workspaceSize, executor);
|
||||
return ret;
|
||||
}
|
||||
|
||||
aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
|
||||
{
|
||||
if (NnopbaseSetHcclServerType) {
|
||||
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
|
||||
}
|
||||
aclnnStatus ret = aclnnInnerDispatchFFNCombine(workspace, workspaceSize, executor, stream);
|
||||
return ret;
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,61 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef OP_API_INC_DISPATCH_FFN_COMBINE_
|
||||
#define OP_API_INC_DISPATCH_FFN_COMBINE_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "aclnn/aclnn_base.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* 算子功能:实现分布式MoE从InitRouting到Unpermute全部算子的融合
|
||||
* @brief aclnnDispatchFFNCombine的第一段接口,根据具体的计算流程,计算workspace大小。
|
||||
* @domain aclnn_ops_infer
|
||||
* @param [in] a: matmul左矩阵,数据类型支持:float16, bf16。
|
||||
* @param [in] b: matmul右矩阵,数据类型支持:float16, bf16。
|
||||
* @param [in] bias: 偏置,数据类型支持:float16, bf16。
|
||||
* @param [in] group: 标识通信域名称的字符串。
|
||||
* @param [in] worldsize: 通信域size,支持2/4/8卡。
|
||||
* @param [in] epRankId: ep本卡Id。取值范围[0, worldSize),各卡的rankId不能重复
|
||||
* @param [out] c: 计算+通信的结果,数据类型:同输入。
|
||||
* @param [out] workspaceSize: 返回需要在npu device侧申请的workspace大小。
|
||||
* @param [out] executor: 返回op执行器,包含了算子计算流程。
|
||||
* @return aclnnStatus: 返回状态码
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
|
||||
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
|
||||
const aclTensor* probs,
|
||||
const char* group, int64_t maxOutputSize,
|
||||
const aclTensor* out,
|
||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||
|
||||
/**
|
||||
* @brief aclnnDispatchGmmCombine的第二段接口,用于执行计算。
|
||||
* @param [in] workspace: 在npu device侧申请的workspace内存起址。
|
||||
* @param [in] workspace_size: 在npu device侧申请的workspace大小,由第一段接口aclnnDispatchFFNCombineGetWorkspaceSize获取。
|
||||
* @param [in] exector: op执行器,包含了算子计算流程。
|
||||
* @param [in] stream: acl stream流。
|
||||
* @return aclnnStatus: 返回状态码
|
||||
*/
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // OP_API_INC_GMM_ALLTOALLV_
|
||||
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dispatch_ffn_combine_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class DispatchFFNCombine : public OpDef {
|
||||
public:
|
||||
explicit DispatchFFNCombine(const char *name) : OpDef(name) {
|
||||
this->Input("a")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("w1")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
|
||||
.IgnoreContiguous();
|
||||
this->Input("w2")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
|
||||
.IgnoreContiguous();
|
||||
this->Input("expertIdx")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("scale1")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("scale2")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("probs")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
// 输出
|
||||
this->Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
|
||||
|
||||
this->Attr("group").AttrType(REQUIRED).String();
|
||||
this->Attr("M").AttrType(OPTIONAL).Int();
|
||||
this->Attr("transB").AttrType(OPTIONAL).Bool(false);
|
||||
this->Attr("weightNz").AttrType(OPTIONAL).Bool(false);
|
||||
|
||||
OpAICoreConfig aicore_config;
|
||||
aicore_config.DynamicCompileStaticFlag(true)
|
||||
.DynamicFormatFlag(true)
|
||||
.DynamicRankSupportFlag(true)
|
||||
.DynamicShapeSupportFlag(true)
|
||||
.NeedCheckSupportFlag(false)
|
||||
.PrecisionReduceFlag(true)
|
||||
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
|
||||
.ExtendCfgInfo("jitCompile.flag", "static_false")
|
||||
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
this->AICore().AddConfig("ascend910b", aicore_config);
|
||||
this->MC2().HcclGroup("group");
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(DispatchFFNCombine);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dispatch_ffn_proto.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include <graph/utils/type_utils.h>
|
||||
#include <register/op_impl_registry.h>
|
||||
// #include "../../common/ophost/op_util.h"
|
||||
// #include "../../common/ophost/hcom_topo_info.h"
|
||||
// #include "log/ops_log.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace ops {
|
||||
const size_t ATTR_GROUP = 0;
|
||||
const size_t ATTR_RANK_SIZE = 1;
|
||||
const size_t SUPPORT_DIM_SIZE = 2;
|
||||
|
||||
static ge::graphStatus InferShapeDispatchFFNCombine(gert::InferShapeContext* context) {
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataTypeDispatchFFNCombine(gert::InferDataTypeContext* context) {
|
||||
// auto d_type = context->GetInputDataType(0);
|
||||
// context->SetOutputDataType(0, d_type);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(DispatchFFNCombine)
|
||||
.InferShape(InferShapeDispatchFFNCombine)
|
||||
.InferDataType(InferDataTypeDispatchFFNCombine);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,265 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
/*!
|
||||
* \file dispatch_ffn_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "vector"
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "tiling/tiling_api.h"
|
||||
#include "error_log.h"
|
||||
#include "hcom_topo_info.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "dispatch_ffn_combine_tiling.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace ge;
|
||||
|
||||
namespace {
|
||||
// 1. 常量定义
|
||||
const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
|
||||
constexpr uint32_t ATTR_GROUP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_IS_TRANS_B = 2;
|
||||
constexpr uint32_t ATTR_WEIGHT_NZ = 3;
|
||||
constexpr uint64_t INIT_TILINGKEY = 1000000;
|
||||
constexpr uint64_t TILINGKEY_TRANS_B = 1U;
|
||||
constexpr uint64_t TILINGKEY_WEIGHT_NZ = 10;
|
||||
constexpr uint32_t X_INDEX = 0;
|
||||
constexpr uint32_t WEIGHT_INDEX = 1;
|
||||
constexpr uint32_t WEIGHT2_INDEX = 2;
|
||||
constexpr uint32_t EXPERTID_INDEX = 3;
|
||||
constexpr uint32_t BLOCK_NUM = 20;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
}
|
||||
|
||||
namespace optiling {
|
||||
|
||||
static int32_t CeilDev(int32_t num, int32_t div)
|
||||
{
|
||||
if (div == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (num + div - 1) / div;
|
||||
}
|
||||
|
||||
// 解析并校验 rankId, group, worldSize, isTransB 属性值
|
||||
static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);
|
||||
|
||||
// todo:Attr相关tilingdata的设置、校验、打印
|
||||
auto groupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
|
||||
auto maxOutputSizePtr = attrs->GetAttrPointer<int>(ATTR_MAX_OUTPUT_SIZE_INDEX);
|
||||
auto is_trans_b = attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B);
|
||||
auto weight_nz = attrs->GetAttrPointer<bool>(ATTR_WEIGHT_NZ);
|
||||
OP_TILING_CHECK(groupPtr == nullptr || strlen(groupPtr) == 0,
|
||||
OP_LOGE(K_INNER_DEBUG, "group is invalid."), return GRAPH_FAILED);
|
||||
|
||||
OP_TILING_CHECK(is_trans_b == nullptr,
|
||||
OP_LOGE(K_INNER_DEBUG, "is_trans_b is invalid."), return GRAPH_FAILED);
|
||||
OP_TILING_CHECK(weight_nz == nullptr,
|
||||
OP_LOGE(K_INNER_DEBUG, "weight_nz is invalid."), return GRAPH_FAILED);
|
||||
|
||||
info.maxOutputSize = *maxOutputSizePtr;
|
||||
info.isTransposeB = *is_trans_b;
|
||||
info.isWeightNz = *weight_nz;
|
||||
|
||||
int64_t rankSize;
|
||||
(void)ge::HcomTopoInfo::Instance().GetGroupRankSize(groupPtr, rankSize);
|
||||
info.worldSize = rankSize;
|
||||
|
||||
OP_LOGD(K_INNER_DEBUG, "maxOutputSize=%d ", info.maxOutputSize);
|
||||
OP_LOGD(K_INNER_DEBUG, "rankSize=%d ", info.worldSize);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// 提取输入张量 A 和 B 的形状,计算出 M、K、N 值
|
||||
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
// OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling.");
|
||||
|
||||
const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
|
||||
const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX);
|
||||
const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX);
|
||||
uint32_t M = aStorageShape->GetStorageShape().GetDim(0);
|
||||
uint32_t K = aStorageShape->GetStorageShape().GetDim(1);
|
||||
uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0);
|
||||
uint32_t N = bStorageShape->GetStorageShape().GetDim(2);
|
||||
uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1);
|
||||
|
||||
info.M = M;
|
||||
info.N = N;
|
||||
info.K = K;
|
||||
info.expertPerRank = expertPerRank;
|
||||
info.topK = topK;
|
||||
OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
|
||||
OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
|
||||
OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
|
||||
OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
|
||||
OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// 获取当前芯片平台的 AI Core 数目、UB 容量等硬件信息。
|
||||
static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
||||
{
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
uint64_t ubSize = 0U;
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
info.aivNum = aivNum;
|
||||
info.totalUbSize = ubSize;
|
||||
|
||||
OP_LOGD(K_INNER_DEBUG, "aivNum=%d", info.aivNum);
|
||||
OP_LOGD(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineInfo &info)
|
||||
{
|
||||
cocTilingData.m0 = 128;
|
||||
cocTilingData.k0 = 256;
|
||||
cocTilingData.n0 = 256;
|
||||
cocTilingData.swizzleDirect = 1;
|
||||
cocTilingData.swizzleOffset = 7;
|
||||
cocTilingData.ubMoveNum = 16 * 1024;
|
||||
cocTilingData.pValue = 1;
|
||||
cocTilingData.commNpuSplit = info.worldSize;
|
||||
cocTilingData.commDataSplit = 1;
|
||||
cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2;
|
||||
}
|
||||
|
||||
// 主调度函数:
|
||||
// 获取 tilingData ➝ 检查 Attr ➝ 检查 Shape ➝ 获取平台信息
|
||||
// ➝ 调用 SetTilingData(根据rank数目) ➝ 设置 blockDim ➝ 设置 tilingKey ➝ 设置 workspace ➝ 配置通信参数
|
||||
|
||||
static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context)
|
||||
{
|
||||
const char *nodeName = context->GetNodeName();
|
||||
OP_LOGI(nodeName, "Enter DispatchFFNCombine tiling func.");
|
||||
|
||||
// 1. tilingData
|
||||
DispatchFFNCombineTilingData *tilingData = context->GetTilingData<DispatchFFNCombineTilingData>();
|
||||
OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_LOGI(nodeName, "DispatchFFNCombine get tilingData.");
|
||||
DispatchFFNCombineInfo& info = tilingData->dispatchFFNCombineInfo;
|
||||
OP_LOGI(nodeName, "DispatchFFNCombine get tilingData info.");
|
||||
|
||||
OP_TILING_CHECK(DispatchFFNCombineCheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
|
||||
OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckAttrAndSetTiling Failed"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_TILING_CHECK(DispatchFFNCombineCheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
|
||||
OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckShapeAndSetTiling Failed"),
|
||||
return ge::GRAPH_FAILED);
|
||||
OP_TILING_CHECK(DispatchFFNCombineGetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
|
||||
OP_LOGE(context->GetNodeName(), "DispatchFFNCombine GetPlatformInfoAndSetTiling Failed"),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
SetTilingData(tilingData->cocTiling, info);
|
||||
|
||||
// 2. set blockDim
|
||||
uint32_t blockDim = 1U;
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
auto aicNum = ascendcPlatform.GetCoreNumAic();
|
||||
auto aivNum = ascendcPlatform.GetCoreNumAiv();
|
||||
blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
|
||||
context->SetBlockDim(blockDim);
|
||||
|
||||
// 3. set tilingKey
|
||||
uint64_t tilingKey = INIT_TILINGKEY;
|
||||
tilingKey += info.isTransposeB ? TILINGKEY_TRANS_B : 0;
|
||||
tilingKey += info.isWeightNz ? TILINGKEY_WEIGHT_NZ : 0;
|
||||
context->SetTilingKey(tilingKey);
|
||||
|
||||
OP_LOGD(K_INNER_DEBUG, "tilingKey=%d", tilingKey);
|
||||
|
||||
optiling::MoeInitRoutingQuantV2TilingBase moeInitRoutingQuantV2TilingBase;
|
||||
int64_t inuptXDtypeSize = sizeof(int16_t);
|
||||
int64_t scaleDim0 = 0;
|
||||
int64_t ubSize = 196352;
|
||||
int64_t expertCapacity = 0;
|
||||
int64_t expertNum = info.expertPerRank * info.worldSize;
|
||||
int64_t activeNum = 0;
|
||||
int64_t dropPadMode = 0;
|
||||
int64_t expertTokensCountOrCumsumFlag = 2;
|
||||
bool expertTokensBeforeCapacityFlag = false;
|
||||
int64_t quantMode = 1;
|
||||
uint32_t aivNumInitRouting = 2 * BLOCK_NUM;
|
||||
moeInitRoutingQuantV2TilingBase.DoTiling(info.M, info.K, info.topK, expertCapacity, expertNum, activeNum, dropPadMode,
|
||||
expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0, aivNumInitRouting, ubSize);
|
||||
uint64_t initRoutingQuantTilingKey = moeInitRoutingQuantV2TilingBase.tilingKey_;
|
||||
size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_;
|
||||
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData = moeInitRoutingQuantV2TilingBase.quantTilingData;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vbsComputeParamsOp;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vmsMiddleComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vmsMiddleComputeParamsOp;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.sortOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.sortOutComputeParamsOp;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstComputeParamsOp;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstCapacityComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstCapacityComputeParamsOp;
|
||||
tilingData->cocTiling.moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.gatherOutComputeParamsOp;
|
||||
tilingData->cocTiling.initRoutingQuantTilingKey = initRoutingQuantTilingKey;
|
||||
|
||||
// 4. workspace
|
||||
size_t *workSpaces = context->GetWorkspaceSizes(1);
|
||||
OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
uint32_t n2 = info.K;
|
||||
uint32_t k2 = info.N / 2;
|
||||
|
||||
uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) +
|
||||
info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 +
|
||||
info.maxOutputSize * sizeof(float) * 2 +
|
||||
std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
|
||||
std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
|
||||
|
||||
workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace);
|
||||
|
||||
|
||||
// 5. communication
|
||||
auto attrs = context->GetAttrs();
|
||||
auto group = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
|
||||
uint32_t opType = 8U;
|
||||
std::string algConfig = "AlltoAll=level0:fullmesh;level1:pairwise";
|
||||
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig);
|
||||
mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling);
|
||||
mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling);
|
||||
|
||||
OP_LOGI(nodeName, "Leave DispatchFFNCombine tiling func.");
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
// Tiling entry point registered with the op-tiling framework.
// It simply forwards to the implementation function defined above.
static ge::graphStatus DispatchFFNCombineTilingFunc(gert::TilingContext* tilingContext)
{
    return DispatchFFNCombineTilingFuncImpl(tilingContext);
}
|
||||
|
||||
// Compile-info placeholder for DispatchFFNCombine: the operator carries no
// per-compilation state, but the framework requires a type for TilingParse.
struct DispatchFFNCombineCompileInfo {};
|
||||
// TilingParse hook for DispatchFFNCombine. No compile-time information is
// consumed by this operator, so the parse step is a deliberate no-op that
// always reports success.
ge::graphStatus TilingParseForDispatchFFNCombine(gert::TilingParseContext *parseContext)
{
    static_cast<void>(parseContext);  // parameter intentionally unused
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Register the tiling function and (no-op) tiling-parse function for the
// DispatchFFNCombine operator with the op-tiling framework.
IMPL_OP_OPTILING(DispatchFFNCombine)
    .Tiling(DispatchFFNCombineTilingFunc)
    .TilingParse<DispatchFFNCombineCompileInfo>(TilingParseForDispatchFFNCombine);
}  // namespace optiling
|
||||
47
csrc/dispatch_ffn_combine/op_host/error_log.h
Normal file
47
csrc/dispatch_ffn_combine/op_host/error_log.h
Normal file
@@ -0,0 +1,47 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

#include <string>
#include "toolchain/slog.h"

// Lightweight logging shims for the op-tiling host code.
// Info/debug logging (OP_LOGI / OP_LOGD) is compiled out entirely; warnings
// and errors are printed to stdout via printf rather than routed through slog.
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...)                 \
    do {                                     \
        printf("[WARN][%s] ", (opname));     \
        printf(__VA_ARGS__);                 \
        printf("\n");                        \
    } while (0)

// Error log without triggering the framework's error-report path.
#define OP_LOGE_WITHOUT_REPORT(opname, ...)  \
    do {                                     \
        printf("[ERRORx][%s] ", (opname));   \
        printf(__VA_ARGS__);                 \
        printf("\n");                        \
    } while (0)

#define OP_LOGE(opname, ...)                 \
    do {                                     \
        printf("[ERROR][%s] ", (opname));    \
        printf(__VA_ARGS__);                 \
        printf("\n");                        \
    } while (0)

#define OP_LOGD(opname, ...)

namespace optiling {

// NOTE(review): "TILIING" spelling is kept as-is; callers elsewhere use this
// exact name.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...)   \
    do {                                                         \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)

// If `cond` holds: run `log_func` (a logging expression) then `expr`
// (typically a `return ...;` to bail out of the enclosing function).
#define OP_TILING_CHECK(cond, log_func, expr) \
    do {                                      \
        if (cond) {                           \
            log_func;                         \
            expr;                             \
        }                                     \
    } while (0)
}  // namespace optiling

#endif  // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
72
csrc/dispatch_ffn_combine/op_host/hcom_topo_info.h
Normal file
72
csrc/dispatch_ffn_combine/op_host/hcom_topo_info.h
Normal file
@@ -0,0 +1,72 @@
|
||||
/* Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_
#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_

#include <unordered_map>
#include <mutex>

using Status = int32_t;

namespace ge {
// Bit flags describing which communication patterns a topology level supports.
static constexpr uint32_t COMM_MESH = 0b1U;
static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U);
static constexpr uint32_t COMM_RING = (COMM_MESH << 2U);
static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U);
// Process-wide registry mapping an HCCL group name to its topology
// information and (optionally) an ordered stream. Declaration-only here:
// method bodies live in the CANN runtime.
// NOTE(review): this header mirrors metadef's external hcom_topo_info.h —
// keep it byte-compatible with the upstream declaration.
class HcomTopoInfo {
public:
  enum class TopoLevel {
    L0 = 0,
    L1,
    MAX,
  };
  struct TopoLevelDesc {
    uint32_t comm_sets;   // bitmask of COMM_* flags supported at this level
    uint32_t rank_size;   // number of ranks at this level
  };
  using TopoDescs = TopoLevelDesc[static_cast<int32_t>(TopoLevel::MAX)];
  struct TopoInfo {
    int64_t rank_size;
    void *notify_handle;
    TopoDescs topo_level_descs;
  };
  static HcomTopoInfo &Instance();
  bool TopoInfoHasBeenSet(const char_t *group);
  bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info);
  Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info);
  Status GetGroupRankSize(const char_t *group, int64_t &rank_size);
  TopoDescs *GetGroupTopoDesc(const char_t *group);
  Status GetGroupNotifyHandle(const char_t *group, void *&notify_handle);
  // Remove a group's topology entry (thread-safe; erase of a missing key is a no-op).
  void UnsetGroupTopoInfo(const char_t *group) {
    const std::lock_guard<std::mutex> lock(mutex_);
    (void) rank_info_.erase(group);
  }

  Status SetGroupOrderedStream(const char_t *group, void *stream);
  Status GetGroupOrderedStream(const char_t *group, void *&stream);
  // Remove a group's ordered stream (thread-safe).
  void UnsetGroupOrderedStream(const char_t *group) {
    const std::lock_guard<std::mutex> lock(mutex_);
    (void) group_to_ordered_stream_.erase(group);
  };

  // Per-device variants of the ordered-stream accessors.
  Status SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream);
  Status GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream);
  void UnsetGroupOrderedStream(const int32_t device_id, const char_t *group);
private:
  HcomTopoInfo() = default;
  ~HcomTopoInfo() = default;
  std::unordered_map<std::string, TopoInfo> rank_info_;
  std::mutex mutex_;
  std::unordered_map<std::string, void*> group_to_ordered_stream_;  // ordered stream per communication group
  std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_;  // ordered stream per (device, group)
};
}

#endif  // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_
|
||||
9
csrc/dispatch_ffn_combine/op_host/tiling_args.h
Normal file
9
csrc/dispatch_ffn_combine/op_host/tiling_args.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifndef TILING_ARGS_H
#define TILING_ARGS_H
#include <cstdint>

namespace Moe {
// Byte offsets into the shared HCCL window.
// NOTE(review): the specific values (3 MiB / 204 MiB) presumably partition the
// window between combine-state flags and dispatch notification — confirm
// against the overall window layout before changing either constant.
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL;
constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL;
}  // namespace Moe
#endif  // TILING_ARGS_H
|
||||
51
csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
Normal file
51
csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
Normal file
@@ -0,0 +1,51 @@
|
||||
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file dispatch_ffn_combine.cpp
 * \brief Device entry point: selects a DispatchFFNCombine template
 *        instantiation based on the host-computed tiling key.
 */
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "dispatch_ffn_combine_tiling.h"
#include "dispatch_ffn_combine.h"

using namespace AscendC;
using namespace DispatchFFNCombineImpl;
// Kernel entry. The tiling key's low bits select the weight layout variants
// (template parameters TB_/Nz_); all variants run as MIX AIC 1:2 tasks.
// NOTE(review): keys 1000000 and 1000010 instantiate the SAME template
// arguments <false, true> — verify whether 1000000 was meant to be
// <false, false>; this looks like a copy-paste remnant but cannot be
// confirmed from this file alone.
extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
    GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
    REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
    if (TILING_KEY_IS(1000000)) {
        KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
        DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000001)) {
        KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
        DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000010)) {
        KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
        DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000011)) {
        KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
        DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
        op.Process();
    }
}
|
||||
276
csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
Normal file
276
csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
Normal file
@@ -0,0 +1,276 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dispatch_ffn_combine.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef DISPATCH_FFN_COMBINE_H
|
||||
#define DISPATCH_FFN_COMBINE_H
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#include "utils/moe_distribute_base.h"
|
||||
|
||||
#include "dispatch_ffn_combine_tiling.h"
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/arch.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/epilogue/block/block_epilogue.hpp"
|
||||
#include "catlass/epilogue/tile/tile_copy.hpp"
|
||||
#include "catlass/epilogue/tile/tile_elemwise_add.hpp"
|
||||
#include "catlass/epilogue/tile/tile_elemwise_muls.hpp"
|
||||
#include "catlass/gemm/block/block_mmad.hpp"
|
||||
#include "catlass/gemm/block/block_swizzle.hpp"
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
#include "catlass/gemm/kernel/matmul_epilogue.hpp"
|
||||
#include "catlass/gemm/gemm_type.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
|
||||
#include "utils/select_helper.hpp"
|
||||
#include "utils/const_args.hpp"
|
||||
#include "dispatch_ffn_combine_kernel.hpp"
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
|
||||
|
||||
using namespace Catlass;
|
||||
|
||||
namespace DispatchFFNCombineImpl {
|
||||
// Template-parameter pack shared by the class definition and its out-of-line
// member definitions: A/B/C element types plus two bools selecting the
// weight layout (TB_: transposed-B, Nz_: zN fractal layout).
#define TemplateMMA2AClass typename AType_, typename BType_, typename CType_, bool TB_, bool Nz_
#define TemplateMMA2ACFunc AType_, BType_, CType_, TB_, Nz_

using namespace AscendC;
// Device-side driver for the fused dispatch + grouped-FFN + combine operator.
// Init() unpacks the tiling; Process() instantiates and runs the CATLASS
// kernel with the concrete layout/epilogue configuration.
template <TemplateMMA2AClass>
class DispatchFFNCombine {
public:
    __aicore__ inline DispatchFFNCombine() {};
    __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
        GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
    __aicore__ inline void Process();


private:
    // Raw global-memory addresses of the operator's inputs/outputs.
    GM_ADDR xGM_;
    GM_ADDR weight1GM_;
    GM_ADDR weight2GM_;
    GM_ADDR expertIdGM_;
    GM_ADDR scale1GM_;
    GM_ADDR scale2GM_;
    GM_ADDR probs_;
    GM_ADDR outGM_;
    GM_ADDR workspaceGM_;

    // Optional MoeInitRoutingQuantV2 inputs; this op does not supply them.
    GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
    GM_ADDR moeInitRoutingQuantV2Offset = nullptr;
    GM_ADDR expertTokensBeforeCapacity = nullptr;


    TBuf<AscendC::TPosition::VECCALC> uBuf_;

    // Communication identity, read from the HCCL window context in Init().
    int32_t rank;
    int32_t rankSize;
    int32_t aivNum;

    // Matmul tile shape and scheduling parameters from cocTiling.
    int32_t m0;
    int32_t k0;
    int32_t n0;
    int32_t swizzlOffset;
    int32_t swizzlDirect;
    int32_t ubMoveNum;
    int32_t pValue;

    int32_t commNpuSplit;
    int32_t commDataSplit;
    int32_t lenPerLoop;

    // Problem description from dispatchFFNCombineInfo.
    int32_t m;
    int32_t k;
    int32_t n;
    int32_t topK;
    int32_t expertPerRank;
    int32_t maxOutputSize;
    int32_t EP;

    // Tiling for the embedded MoeInitRoutingQuantV2 (token routing + quant) step.
    optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
    uint64_t initRoutingQuantTilingKey;

    // Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;

};
|
||||
|
||||
|
||||
template <TemplateMMA2AClass>
|
||||
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
||||
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||
{
|
||||
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
||||
auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
|
||||
GET_TILING_DATA(tilingData, tilingGM);
|
||||
|
||||
xGM_ = xGM;
|
||||
weight1GM_ = weight1GM;
|
||||
weight2GM_ = weight2GM;
|
||||
expertIdGM_ = expertIdGM;
|
||||
scale1GM_ = scale1GM;
|
||||
scale2GM_ = scale2GM;
|
||||
probs_ = probs;
|
||||
|
||||
outGM_ = outGM;
|
||||
|
||||
workspaceGM_ = workspaceGM;
|
||||
|
||||
aivNum = tilingData.dispatchFFNCombineInfo.aivNum;
|
||||
|
||||
m = tilingData.dispatchFFNCombineInfo.M;
|
||||
k = tilingData.dispatchFFNCombineInfo.K;
|
||||
n = tilingData.dispatchFFNCombineInfo.N;
|
||||
EP = tilingData.dispatchFFNCombineInfo.worldSize;
|
||||
topK = tilingData.dispatchFFNCombineInfo.topK;
|
||||
expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank;
|
||||
maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize;
|
||||
|
||||
m0 = tilingData.cocTiling.m0;
|
||||
k0 = tilingData.cocTiling.k0;
|
||||
n0 = tilingData.cocTiling.n0;
|
||||
swizzlDirect = tilingData.cocTiling.swizzleDirect;
|
||||
swizzlOffset = tilingData.cocTiling.swizzleOffset;
|
||||
ubMoveNum = tilingData.cocTiling.ubMoveNum;
|
||||
pValue = tilingData.cocTiling.pValue;
|
||||
commNpuSplit = tilingData.cocTiling.commNpuSplit;
|
||||
commDataSplit = tilingData.cocTiling.commDataSplit;
|
||||
lenPerLoop = tilingData.cocTiling.lenPerLoop;
|
||||
moeInitRoutingQuantV2TilingData = tilingData.cocTiling.moeInitRoutingQuantV2TilingData;
|
||||
initRoutingQuantTilingKey = tilingData.cocTiling.initRoutingQuantTilingKey;
|
||||
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
|
||||
|
||||
rank = WinContext_->localUsrRankId;
|
||||
rankSize = WinContext_->rankSize;
|
||||
}
|
||||
|
||||
// Build the concrete CATLASS kernel type (layouts, dispatch policies, two
// epilogues) from the template parameters and member state, then run it.
// GMM1 output feeds a per-token-dequant + SwiGLU + quant epilogue; GMM2
// output feeds a per-token-dequant epilogue that produces the final dtype.
template <TemplateMMA2AClass>
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
{
    // Define ArchTag
    using ArchTag = Arch::AtlasA2;
    constexpr bool enableUnitFlag = false;
    constexpr bool enableShuffleK = true;

    // Second GEMM shape: SwiGLU halves the hidden dim (k2 = n/2) and the
    // output projects back to the model dim (n2 = k).
    uint32_t k2 = n/2;
    uint32_t n2 = k;

    // NOTE(review): the following routing parameters are declared but not
    // passed anywhere in this function — confirm whether they are leftovers
    // from an earlier MoeInitRoutingQuantV2 call site.
    int64_t activeNum = 0;
    int64_t expertCapacity = 0;
    int64_t expertNum = expertPerRank * EP;
    int64_t dropPadMode = 0;
    int64_t expertTokensCountOrCumsumFlag = 2;
    bool expertTokensBeforeCapacityFlag = false;
    int64_t quantMode = 1;

    // Weight layout is selected by the template parameters: zN fractal when
    // Nz_, otherwise column- or row-major depending on TB_.
    using LayoutA = layout::RowMajor;
    using LayoutB = typename std::conditional<
        Nz_,
        layout::zN,
        typename std::conditional<TB_, layout::ColumnMajor, layout::RowMajor>::type
    >::type;

    LayoutB layoutB1 = LayoutBInitializer<LayoutB, BType_>::create(k, n);
    LayoutB layoutB2 = LayoutBInitializer<LayoutB, BType_>::create(k2, n2);
    using LayoutC = layout::RowMajor;
    using L1TileShape = GemmShape<128, 256, 512>; // M, N, K

    // Pipeline staging depths for the mmad dispatch policy.
    constexpr uint32_t workspaceStages = 2;
    constexpr uint32_t preloadStages = 1;
    constexpr uint32_t l1Stages = 2;
    constexpr uint32_t l0AStages = 2;
    constexpr uint32_t l0BStages = 2;
    constexpr uint32_t l0CStages = 1;

    using DispatchPolicy = Gemm::MmadAtlasA2PreloadAsyncFixpipe<
        preloadStages,
        l1Stages, l0AStages, l0BStages, l0CStages,
        enableUnitFlag, enableShuffleK
    >;

    using L0TileShape = GemmShape<128, 256, 128>;
    // int8 quantized inputs; fp16 accum output from the cube, re-quantized to
    // int8 between the two GEMMs by epilogue 1.
    using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
    using BType = Gemm::GemmType<int8_t, LayoutB>;
    using CType = Gemm::GemmType<float16_t, layout::RowMajor>;
    using D1Type = Gemm::GemmType<int8_t, layout::RowMajor>;

    // Final output element type follows CType_ (bf16 handled explicitly).
    using D2Type = typename std::conditional<
        std::is_same_v<CType_, bfloat16_t>,
        Gemm::GemmType<bfloat16_t, layout::RowMajor>,
        Gemm::GemmType<CType_, layout::RowMajor>
    >::type;

    using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
    constexpr uint32_t ubStages = 2;

    // Epilogue 1: per-token dequant + SwiGLU activation + re-quant to int8.
    using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant<ubStages>;

    using ScaleType = Gemm::GemmType<uint64_t, layout::VectorLayout>;
    using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
    using ElementMulType = Gemm::GemmType<float, layout::RowMajor>;
    using TileElemWiseMuls = Epilogue::Tile::TileElemWiseMuls<ArchTag, ElementMulType, 0>;

    using TileCopy1 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D1Type>;
    using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
        D1Type, TileElemWiseMuls, TileCopy1>;

    // Epilogue 2: per-token dequant to the final output dtype.
    using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
    using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
    using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
        D2Type, TileCopy2>;

    using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<9, 1>;
    using ElementGroupList = int64_t;
    using MatmulKernel = Gemm::Kernel::DispatchFFNCombineKernel<BlockMmad,
        BlockScheduler, ElementGroupList, BlockEpilogue1, BlockEpilogue2>;

    LayoutA layoutA1{static_cast<uint32_t>(m), static_cast<uint32_t>(k)};
    LayoutA layoutA2{static_cast<uint32_t>(m), static_cast<uint32_t>(k2)};
    layout::VectorLayout layoutScale1{static_cast<uint32_t>(n)};
    layout::VectorLayout layoutScale2{static_cast<uint32_t>(n2)};
    layout::RowMajor layoutD1{static_cast<uint32_t>(maxOutputSize), static_cast<uint32_t>(k2)};
    layout::RowMajor layoutD2{static_cast<uint32_t>(m*topK), static_cast<uint32_t>(n2)};
    // Prepare params

    GemmCoord problemShape{static_cast<uint32_t>(m), static_cast<uint32_t>(n), static_cast<uint32_t>(k)};

    // Half of the vector cores run the epilogue; granularity controls how
    // many experts are processed before handing over.
    uint32_t epilogueCoreNum = aivNum / 2;
    uint32_t epilogueGranularity = expertPerRank - 1;

    typename MatmulKernel::Params params{
        problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
        static_cast<uint32_t>(rank), static_cast<uint32_t>(rankSize),
        static_cast<uint32_t>(topK), initRoutingQuantTilingKey,
        epilogueCoreNum, epilogueGranularity,
        xGM_, layoutA1, layoutA2,
        weight1GM_, layoutB1,
        weight2GM_, layoutB2,
        scale1GM_, layoutScale1,
        scale2GM_, layoutScale2,
        outGM_, layoutD1, layoutD2,
        expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
        expertTokensBeforeCapacity, probs_,
        workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
    //Call kernel
    MatmulKernel kernel(params);
    kernel(params);
}
|
||||
|
||||
} // DispatchFFNCombineImpl
|
||||
#endif // DISPATCH_FFN_COMBINE_H
|
||||
@@ -0,0 +1,814 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef DISPATH_FFN_COMBINE_KERNEL_HPP
|
||||
#define DISPATH_FFN_COMBINE_KERNEL_HPP
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/cross_core_sync.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/coord.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/epilogue/tile/tile_copy.hpp"
|
||||
|
||||
#include "utils/block_mmad_preload_async_fixpipe_quant.hpp"
|
||||
#include "utils/copy_gm_to_l1_custom.hpp"
|
||||
#include "utils/copy_l0c_to_gm_custom.hpp"
|
||||
#include "utils/block_epilogue_pertoken_row.hpp"
|
||||
#include "utils/block_epilogue_pertoken_swiglu.hpp"
|
||||
#include "utils/hccl_shmem.hpp"
|
||||
#include "utils/const_args.hpp"
|
||||
#include "utils/layout3d.hpp"
|
||||
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
|
||||
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
|
||||
#include "unpermute/moe_token_unpermute.h"
|
||||
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
namespace Catlass::Gemm::Kernel {
|
||||
|
||||
template <
|
||||
class BlockMmad_,
|
||||
class BlockScheduler_,
|
||||
class ElementGroupList_,
|
||||
class BlockEpilogue1_,
|
||||
class BlockEpilogue2_
|
||||
>
|
||||
class DispatchFFNCombineKernel {
|
||||
public:
|
||||
using BlockMmad = BlockMmad_;
|
||||
using ArchTag = typename BlockMmad::ArchTag;
|
||||
using L1TileShape = typename BlockMmad::L1TileShape;
|
||||
using ElementA = typename BlockMmad::ElementA;
|
||||
using LayoutA = typename BlockMmad::LayoutA;
|
||||
using ElementB = typename BlockMmad::ElementB;
|
||||
using LayoutB = typename BlockMmad::LayoutB;
|
||||
using ElementC = typename BlockMmad::ElementC;
|
||||
using LayoutC = typename BlockMmad::LayoutC;
|
||||
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
||||
using ElementScale = uint64_t;
|
||||
using LayoutScale = typename layout::VectorLayout;
|
||||
using ElementPerTokenScale = float;
|
||||
using LayoutPerTokenScale = typename layout::VectorLayout;
|
||||
using BlockScheduler = BlockScheduler_;
|
||||
using BlockEpilogue1 = BlockEpilogue1_;
|
||||
using BlockEpilogue2 = BlockEpilogue2_;
|
||||
using ElementD1 = typename BlockEpilogue1::ElementD;
|
||||
using LayoutD1 = typename BlockEpilogue1::LayoutD;
|
||||
using ElementD2 = typename BlockEpilogue2::ElementD;
|
||||
using LayoutD2 = typename BlockEpilogue2::LayoutD;
|
||||
|
||||
    /// Parameters structure: everything the device kernel needs, passed by
    /// value from DispatchFFNCombine::Process().
    /// NOTE(review): the constructor's init-list order differs from the member
    /// declaration order; C++ initializes in declaration order, and no
    /// initializer here reads another member, so behavior is fine — but
    /// -Wreorder-aware compilers will warn.
    /// NOTE(review): the `layoutD1_` parameter is typed LayoutD2, not
    /// LayoutD1 — harmless while both alias layout::RowMajor, but confirm.
    struct Params {
        // Data members
        GemmCoord problemShape;                // GMM1 problem size (M, N, K)
        __gm__ ElementA *ptrA;                 // quantized input tokens
        LayoutA layoutA;                       // layout for GMM1 input
        LayoutA layoutA2;                      // layout for GMM2 input
        __gm__ ElementB *ptrB1;                // FFN up/gate weights
        LayoutB layoutB1;
        __gm__ ElementB *ptrB2;                // FFN down weights
        LayoutB layoutB2;
        __gm__ ElementScale *ptrScale1;        // per-channel dequant scale, GMM1
        LayoutScale layoutScale1;
        __gm__ ElementScale *ptrScale2;        // per-channel dequant scale, GMM2
        LayoutScale layoutScale2;
        __gm__ ElementD2 *ptrOutput;           // final combined output
        LayoutD1 layoutD1;
        LayoutD2 layoutD2;
        GM_ADDR ptrWorkspace;                  // device scratch (see WorkspaceInfo)
        int32_t EP;                            // expert-parallel world size
        int32_t expertPerRank;
        uint32_t maxOutputSize;                // capacity of dispatched-token buffers
        uint32_t rank;
        uint32_t rankSize;
        int32_t ubMoveNum;                     // elements per UB staging chunk
        //-------------- MoE routing (MoeInitRoutingQuantV2) arguments
        GM_ADDR expertIdx;
        GM_ADDR moeInitRoutingQuantV2Scale;
        GM_ADDR moeInitRoutingQuantV2Offset;
        GM_ADDR expandedX;
        GM_ADDR expandedRowIdx;
        GM_ADDR expertTokensCountOrCumsum;
        GM_ADDR expertTokensBeforeCapacity;
        GM_ADDR dynamicQuantScale;
        GM_ADDR probs;                         // top-k routing probabilities
        int64_t topK;
        uint64_t initRoutingQuantTilingKey;
        uint32_t epilogueCoreNum;
        uint32_t epilogueGranularity;
        optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
        //--------------

        // Methods
        CATLASS_HOST_DEVICE
        Params() {}

        CATLASS_HOST_DEVICE
        Params(
            GemmCoord problemShape_,
            uint32_t EP_, uint32_t expertPerRank_, uint32_t maxOutputSize_,
            uint32_t rank_, uint32_t rankSize_, int64_t topK_,
            uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_,
            GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_,
            GM_ADDR ptrB1_, LayoutB layoutB1_,
            GM_ADDR ptrB2_, LayoutB layoutB2_,
            GM_ADDR ptrScale1_, LayoutScale layoutScale1_,
            GM_ADDR ptrScale2_, LayoutScale layoutScale2_,
            GM_ADDR ptrOutput_, LayoutD2 layoutD1_, LayoutD2 layoutD2_,
            GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
            GM_ADDR moeInitRoutingQuantV2Offset_,
            GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
            GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
            optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
        ) : problemShape(problemShape_),
            EP(EP_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
            rank(rank_), rankSize(rankSize_), topK(topK_),
            initRoutingQuantTilingKey(initRoutingQuantTilingKey_),
            epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_),
            ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_),
            ptrB1(reinterpret_cast<__gm__ ElementB *>(ptrB1_)), layoutB1(layoutB1_),
            ptrB2(reinterpret_cast<__gm__ ElementB *>(ptrB2_)), layoutB2(layoutB2_),
            ptrScale1(reinterpret_cast<__gm__ ElementScale *>(ptrScale1_)), layoutScale1(layoutScale1_),
            ptrScale2(reinterpret_cast<__gm__ ElementScale *>(ptrScale2_)), layoutScale2(layoutScale2_),
            ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_),
            expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
            moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
            expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
            ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),
            moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
        {
        }
    };
|
||||
|
||||
    // Methods
    // Constructor: compute this core's logical index/count (the AIV formula
    // folds the sub-block id in, since each AIC pairs with two AIVs in
    // MIX_AIC_1_2 mode), then bind all global/workspace buffers.
    CATLASS_DEVICE
    DispatchFFNCombineKernel(Params const &params)
    {
        if ASCEND_IS_AIC {
            coreIdx = AscendC::GetBlockIdx();
            coreNum = AscendC::GetBlockNum();
        }

        if ASCEND_IS_AIV {
            coreIdx = get_block_idx() + get_subblockid() * get_block_num();
            coreNum = get_block_num() * get_subblockdim();
        }

        initBuffer(params);
    }
|
||||
|
||||
    // Trivial destructor: all buffers are views over externally-owned memory.
    CATLASS_DEVICE
    ~DispatchFFNCombineKernel()
    {
    }

    // Entry point, specialized per core type below.
    template <int32_t CORE_TYPE = g_coreType>
    CATLASS_DEVICE
    void operator()(Params const &params);

    // Cube cores: run GMM1, wait for the vector cores to finish dispatch
    // (cross-core flag 2), then run GMM2.
    template <>
    CATLASS_DEVICE
    void operator()<AscendC::AIC>(Params const &params)
    {
        GMM1(params);

        AscendC::CrossCoreWaitFlag<0x2>(2);

        GMM2(params);
    }


    // Vector cores: dispatch tokens, sync all AIVs, release the cube cores
    // (cross-core flag 2 on the MTE3 pipe), then run the combine phase.
    template <>
    CATLASS_DEVICE
    void operator()<AscendC::AIV>(Params const &params)
    {
        Dispatch(params);
        AscendC::SyncAll<true>();
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);

        Combine(params);
    }
|
||||
|
||||
private:
|
||||
    // Bind all GlobalTensor views onto the workspace regions and the shared
    // HCCL window (peer memory). Pure pointer wiring, no data movement.
    CATLASS_DEVICE void initBuffer(Params const &params) {
        workspaceInfo = WorkspaceInfo(params);
        peermemInfo = PeermemInfo(params, shmem);

        // Cross-rank expert-token cumsum used by the grouped matmuls.
        cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM));

        // GMM1 operands: dispatched tokens, per-channel scales, fp16 output.
        gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA));
        gmS.SetGlobalBuffer(params.ptrScale1);
        gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC));

        // GMM2 operands: re-quantized SwiGLU output, scales, fp16 output.
        gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken));
        gmS2.SetGlobalBuffer(params.ptrScale2);
        gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2));

        // Per-token dequant scales for the two epilogues.
        gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale));
        gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2));

        // Token-per-expert counts live in the shared window so peers can read them.
        tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));

        // +8: rows are padded — presumably for 32-byte alignment; see Layout3D.
        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank + 8, params.expertPerRank);
    }
|
||||
|
||||
    // Copy `elemNum` elements GM -> GM by staging through the unified buffer
    // in chunks of `ubMoveNum`, double-buffered (ping-pong on two UB halves,
    // two event IDs) so MTE2 loads overlap MTE3 stores.
    template<typename T>
    CATLASS_DEVICE void CopyGMToGM(
        AscendC::GlobalTensor<T> dst,
        AscendC::GlobalTensor<T> src,
        int32_t elemNum,
        int32_t ubMoveNum
    )
    {
        // Pre-set both MTE3->MTE2 flags so the first Wait of each buffer passes.
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);

        using TType = Gemm::GemmType<T, layout::RowMajor>;
        using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
        using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
        CopyGmToUb copyGmToUb;
        CopyUbToGm copyUbToGm;
        constexpr int32_t BufferNum = 2;
        int tmpBufferSize = 32 * 1024 / sizeof(T); // 32 KB staging buffer
        AscendC::LocalTensor<T> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<T>(0);
        tmpBuffer1.SetSize(tmpBufferSize);
        int tmpBufferOffset = 96 * 1024; // second buffer starts at half of UB
        AscendC::LocalTensor<T> tmpBuffer2 = resource.ubBuf.template GetBufferByByte<T>(tmpBufferOffset);
        tmpBuffer2.SetSize(tmpBufferSize);

        int pingpongId = 0;
        auto processCount = CeilDiv(elemNum, ubMoveNum);
        for (uint32_t processIndex = 0; processIndex < processCount; ++processIndex) {
            // Last chunk may be short.
            uint32_t curProcessNum = (processIndex == processCount - 1) ? elemNum - ubMoveNum * (processCount - 1) : ubMoveNum;
            AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1;
            AscendC::LocalTensor<T> buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2;
            auto processOffset = processIndex * ubMoveNum;

            auto inputOffset = processOffset;
            auto outputOffset = processOffset;
            // Wait until this buffer's previous store has drained.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
            copyGmToUb(buf, src[inputOffset], layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum});
            // Load must complete before the store reads the buffer.
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
            copyUbToGm(dst[outputOffset], buf, layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum});

            // Mark the buffer reusable for the next iteration of this half.
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
            pingpongId = (pingpongId + 1) % BufferNum;
        }
        // Drain both flags so later code starts from a clean event state.

        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
    }
|
||||
|
||||
CATLASS_DEVICE
// Compute, for this rank's experts, the running (prefix) sum of token counts
// across all EP ranks. `tokenPerExpert` holds the all-gathered per-rank token
// count matrix; the EP rows of cumulative sums (expertPerRank int32 each) are
// written to `result`. Runs on a single vector core (caller gates on coreIdx).
//
// Fixes vs. previous revision: removed the unused local `tmpResult`, and the
// helper macro U16 is now #undef'd so it no longer leaks into the rest of the
// translation unit.
void GetCumsumForMMAIV(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP)
{
    // UB row stride: expertPerRank rounded up to 8 int32 (one 32-byte block).
    int32_t expertPerRankAligned = (expertPerRank + 8 - 1) / 8 * 8;
    AscendC::LocalTensor<int32_t> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<int32_t>(0);
#define U16(x) static_cast<uint16_t>(x)

    // Strided gather: copy this rank's expertPerRank-entry slice out of each of
    // the EP rows of tokenPerExpert into consecutive block-aligned UB rows.
    AscendC::DataCopyPad(
        tmpBuffer1,
        tokenPerExpert[rankId * expertPerRank],
        {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank + 8) * sizeof(int32_t)), 0},
        {}
    );

    // MTE2 -> Vector: the gather must have landed in UB before we add.
    AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);

    // In-place prefix sum over the EP rows: row i += row i-1.
    for (uint32_t i = 1; i < EP; ++i) {
        AscendC::Add(tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[(i - 1) * expertPerRankAligned], expertPerRank);
        AscendC::PipeBarrier<PIPE_V>();
    }

    // Vector -> MTE3: sums must be final before the copy-out starts.
    AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);

    // Copy the EP cumulative rows back to GM.
    // NOTE(review): the source stride here is 0, which is only gap-free when
    // expertPerRank is a multiple of 8 (so the aligned UB row stride equals the
    // payload) — confirm callers guarantee that.
    AscendC::DataCopyPad(
        result,
        tmpBuffer1,
        {U16(EP), U16((expertPerRank) * sizeof(int32_t)), 0, 0}
    );
#undef U16
}
|
||||
|
||||
CATLASS_DEVICE
// First grouped matmul (FFN up-projection) over the dispatched tokens.
// Iterates over this rank's experts ("groups"); for each group, reads the row
// count from the cumsum table produced by the AIV side, clamps it against
// maxOutputSize, and tiles the matmul across cube cores. Cross-core flags gate
// each group on the corresponding dispatch/communication round.
void GMM1(Params const &params){
    icache_preload(8);
    BlockScheduler blockScheduler;
    BlockMmad blockMmad(resource);

    // Running element offsets into the grouped A/B/C tensors.
    int64_t gmGroupOffsetA = 0;
    int64_t gmGroupOffsetB = 0;
    int64_t gmGroupOffsetC = 0;
    uint32_t startCoreIdx = 0;
    uint32_t syncGroupIdx = 0;
    AscendC::CrossCoreWaitFlag<0x2>(0); // wait for AIV to finish computing the cumsum for the matmul
    int64_t preCurrentmSum = 0;
    int32_t syncLoopIdx = -1;
    for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Total rows for this expert across all peer ranks (last row of cumsum).
        uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        // Clamp against the output capacity: drop groups entirely past the cap,
        // truncate the group that straddles it.
        if (preCurrentmSum >= params.maxOutputSize) {
            currentM = 0;
        } else if (preCurrentmSum + currentM >= params.maxOutputSize) {
            currentM = params.maxOutputSize - preCurrentmSum;
        }
        AscendC::GlobalTensor<ElementB> gmB1;
        gmB1.SetGlobalBuffer(params.ptrB1);
        // Small groups won't reuse weights; skip L2 caching for them.
        if (currentM <= L1TileShape::M) {
            gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
        }
        GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
        LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
        LayoutB layoutB1 = params.layoutB1;
        LayoutScale layoutScale = params.layoutScale1;
        LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n());
        blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
        uint32_t coreLoops = blockScheduler.GetCoreLoops();
        // Determine the starting loopIdx of the current core under the current groupIdx
        uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
        // Loop through the matmul of each groupIdx

        for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
            // Lazily consume one cross-core flag per group we have caught up to:
            // the AIV side raises one flag per completed communication round.
            for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
                AscendC::CrossCoreWaitFlag<0x2>(0);
            }
            // Compute block location
            GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
            GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
            // Compute initial location in logical coordinates
            MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
            MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
            MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N};
            int64_t gmOffsetA = layoutA.GetOffset(offsetA);
            int64_t gmOffsetB = layoutB1.GetOffset(offsetB);
            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
            int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // one group of scales per expert
            if (currentM > 0) {
                blockMmad(
                    gmA[gmGroupOffsetA + gmOffsetA], layoutA,
                    gmB1[gmGroupOffsetB + gmOffsetB], layoutB1,
                    gmC[gmGroupOffsetC + gmOffsetC], layoutC,
                    gmS[gmOffsetS], layoutScale,
                    actualBlockShape
                );
            }
        }

        // After `epilogueGranularity` groups, flush once so the AIV epilogue can
        // start dequantizing the first chunk while later groups still compute.
        if ((groupIdx + 1) == params.epilogueGranularity && (groupIdx < params.expertPerRank - 1)) {
            syncLoopIdx ++;
            if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
                blockMmad.SynchronizeBlock();
            }
            blockMmad.Finalize(syncLoopIdx, 1);
        }

        preCurrentmSum += currentM;
        gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
        gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
        gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
        startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
    }
    // Final drain + flush for the remaining groups.
    if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
        blockMmad.SynchronizeBlock();
    }
    blockMmad.Finalize(syncLoopIdx + 1, 1);
}
|
||||
|
||||
CATLASS_DEVICE
// Second grouped matmul (FFN down-projection). Input is the permuted/activated
// token buffer produced by the GMM1 epilogue; output shape is (rows, K) where
// K is the model hidden size (problemShape.k()), and the reduction dimension is
// half of GMM1's N (gated activation halves the width).
void GMM2(Params const &params) {
    icache_preload(8);
    BlockScheduler blockScheduler;
    BlockMmad blockMmad(resource);

    // GMM2 problem dims derived from the GMM1 problem shape:
    // n2 = hidden size, k2 = intermediate size after the gating split.
    uint32_t n2 = params.problemShape.k();
    uint32_t k2 = params.problemShape.n() / 2;

    // Running element offsets into the grouped A/B/C tensors.
    int64_t gmGroupOffsetA = 0;
    int64_t gmGroupOffsetB = 0;
    int64_t gmGroupOffsetC = 0;

    uint32_t startCoreIdx = 0;

    AscendC::PipeBarrier<PIPE_ALL>();

    int64_t preCurrentmSum = 0;
    int32_t syncLoopIdx = -1;
    // Number of experts whose dequant is deferred to the tail epilogue.
    uint32_t lastDequantExpertNum = params.expertPerRank;

    if (params.epilogueGranularity < params.expertPerRank) {
        lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
    }
    for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Total rows for this expert across all peer ranks, clamped to capacity.
        uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        if (preCurrentmSum >= params.maxOutputSize) {
            currentM = 0;
        } else if (preCurrentmSum + currentM > params.maxOutputSize) {
            currentM = params.maxOutputSize - preCurrentmSum;
        }
        AscendC::GlobalTensor<ElementB> gmB2;
        gmB2.SetGlobalBuffer(params.ptrB2);
        // Small groups won't reuse weights; skip L2 caching for them.
        if (currentM <= L1TileShape::M) {
            gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
        }
        GemmCoord inGroupProblemShape{currentM, n2, k2}; // M N K

        LayoutA layoutA = params.layoutA2.GetTileLayout(inGroupProblemShape.GetCoordMK());
        LayoutB layoutB2 = params.layoutB2;
        LayoutScale layoutScale = params.layoutScale2;
        LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n());

        blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
        uint32_t coreLoops = blockScheduler.GetCoreLoops();

        // Determine the starting loopIdx of the current core under the current groupIdx
        uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
        // Loop through the matmul of each groupIdx
        // Before entering the experts handled by the tail epilogue, wait for the
        // AIV side to signal (flag 2) that their inputs are fully dequantized.
        if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
            AscendC::CrossCoreWaitFlag<0x2>(2);
        }
        for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
            // Remember the last group this core contributes a final tile for,
            // so the combine side can be signalled per group.
            if (loopIdx + coreNum >= coreLoops) {
                syncLoopIdx = groupIdx;
            }

            // Compute block location
            GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
            GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);

            // Compute initial location in logical coordinates
            MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
            MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
            MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N};

            int64_t gmOffsetA = layoutA.GetOffset(offsetA);
            int64_t gmOffsetB = layoutB2.GetOffset(offsetB);
            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
            int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // one group of scales per expert
            if (currentM > 0) {
                blockMmad(
                    gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
                    gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
                    gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
                    gmS2[gmOffsetS], layoutScale,
                    actualBlockShape, syncLoopIdx, 3
                );
            }
        }
        preCurrentmSum += currentM;
        gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
        gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
        gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();

        startCoreIdx = (startCoreIdx + coreLoops) % coreNum;

    }

    // Final drain + flush; flag group 3 is consumed by the Combine stage.
    if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
        blockMmad.SynchronizeBlock();
    }
    blockMmad.Finalize(params.expertPerRank - 1, 3);
}
|
||||
|
||||
CATLASS_DEVICE
// Cross-rank barrier fused with the all-gather of the local token-per-expert
// counts. Core i (i != rank, i < EP) pushes this rank's count row, plus a
// monotonically increasing sequence number used as the barrier epoch, into
// rank i's peer memory, then spins until the symmetric row from rank i has
// arrived locally with the same epoch. Finally all cores sync and the new
// epoch is persisted in sync_base.
//
// Fix vs. previous revision: removed the unused local `flag_offset`.
void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){
    __gm__ int32_t* sync_base = shmem.SyncBaseAddr();
    // Barrier epoch: previous epoch + 1.
    int count = gm_load(sync_base) + 1;
    if (coreIdx < params.EP && coreIdx != params.rank) {
        // Source: this rank's local count row in its own peer-memory segment.
        AscendC::GlobalTensor<int32_t> srcAddress;
        srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
        // Destination: the same offset inside remote rank `coreIdx`'s segment.
        AscendC::GlobalTensor<int32_t> dstAddress;
        __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
        dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);

        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
        using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
        using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
        using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
        CopyGmToUb copyGmToUb;
        CopyUbToGm copyUbToGm;
        AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
        uint32_t tmp = params.EP * params.expertPerRank;
        // Stage the count row in UB...
        copyGmToUb(tmpBuffer, srcAddress[0],
            layout::RowMajor{ 1, tmp},
            layout::RowMajor{1, tmp});

        // ...append the epoch number as a trailing sentinel element...
        tmpBuffer.SetValue(params.EP * params.expertPerRank, count);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
        // ...and push row + sentinel to the remote rank in one copy, so the
        // sentinel is written no earlier than the payload.
        copyUbToGm(dstAddress[0], tmpBuffer,
            layout::RowMajor{ 1, tmp + 1},
            layout::RowMajor{1, tmp + 1});
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);

        // Wait until the symmetric transfer from rank `coreIdx` has landed
        // locally (its sentinel equals the current epoch).
        __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(coreIdx, params.EP, 0);
        gm_signal_wait_until_eq_for_barrier(sync_check, count);
    }
    AscendC::SyncAll<true>();
    // Persist the epoch so the next barrier uses count + 1.
    gm_store(sync_base, count);
}
|
||||
|
||||
|
||||
CATLASS_DEVICE
// AIV-side dispatch stage: quantized MOE routing, cross-rank token exchange
// through peer memory, and (optionally chunked) dequant epilogue of the GMM1
// outputs. Cross-core flags coordinate with the cube-side GMM1/GMM2.
//
// Fixes vs. previous revision: removed a stray `;;` empty statement and the
// unused locals `curGroupOffset` and `groupIdxDeq`.
void Dispatch(Params const &params) {
    icache_preload(8);
    int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
    GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // the whole communication matrix lives in peermem
    uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);

    // --- init routing: quantize + permute tokens by expert, counts go to peermem ---
    moe_init_routing_quant_v2<ElementD2>(reinterpret_cast<GM_ADDR> (params.ptrA), params.expertIdx,
        params.moeInitRoutingQuantV2Scale, params.moeInitRoutingQuantV2Offset, shmem() + peermemInfo.offsetA,
        workspaceInfo.expandedRowIdx, localTokenPerExpert, params.expertTokensBeforeCapacity,
        shmem() + peermemInfo.offsetPeerPerTokenScale,
        params.ptrWorkspace + expandedRowIdxOffset,
        &params.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey);

    AscendC::SyncAll<true>();
    // Barrier + all-gather of the token-per-expert matrix across ranks.
    CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset);
    if (coreIdx == 0) {
        GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
    }
    AscendC::SyncAll<true>();
    // Tell the cube side the cumsum table is ready.
    AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);

    // Rows contributed by experts that belong to ranks before ours, as seen
    // from peer rank `coreIdx`; cached for Combine.
    int32_t prevSumBeforeRank = 0;
    if (coreIdx < params.EP) {
        for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
            prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i));
        }
        m_prevSumBeforeRank = prevSumBeforeRank;
    }
    int prevSum = prevSumBeforeRank;
    uint32_t prevGroupSum1 = 0;   // rows of all groups pulled so far
    uint32_t dequantSum = 0;      // rows already covered by an earlier epilogue chunk
    int32_t syncLoopIdx = -1;
    BlockEpilogue1 blockEpilogue(resource);
    for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Core i pulls this group's tokens from rank i's peer memory.
        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
            uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
            if (rowStart < params.maxOutputSize) {
                uint32_t rows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
                // Truncate the transfer that straddles the capacity limit.
                if (rowStart + rows > params.maxOutputSize) {
                    rows = params.maxOutputSize - rowStart;
                }
                uint32_t rowSrc = prevSum;
                prevSum += rows;
                GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
                AscendC::GlobalTensor<ElementA> gmRemoteA;
                gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
                AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
                gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
                MatrixCoord offsetA{rowStart, 0};
                MatrixCoord shapeA{rows, params.problemShape.k()};
                MatrixCoord offsetPeer{rowSrc, 0};
                int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
                int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
                // Transfer the token data...
                CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
                // ...and the per-token quantization scales.
                CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
            }
        }

        // When epilogue chunking is enabled, wait (at the last group) for GMM1's
        // intermediate flush before dequantizing the first chunk.
        if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
            syncLoopIdx++;
            AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
        }
        AscendC::SyncAll<true>();
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // vector cores tell cube cores this round's communication is done

        // Chunked dequant of the first prevGroupSum1 rows of the GMM1 output.
        if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
            uint32_t rowStartThisCore = 0;
            MatrixCoord offsetC{0U, 0};
            uint32_t dequantLen = prevGroupSum1 - dequantSum;
            if (dequantLen >= params.maxOutputSize) {
                dequantLen = dequantLen - params.maxOutputSize;
            }

            MatrixCoord shapeC{dequantLen, params.problemShape.n()};
            LayoutC layoutC{dequantLen, params.problemShape.n()};
            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
            int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
            blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
        }
        prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        // Rows before the chunk boundary are handled by the early epilogue above.
        if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
            dequantSum = 0;
        }
    }
    // Wait for GMM1's final flush before the tail dequant.
    syncLoopIdx++;
    AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
    AscendC::SyncAll<true>();

    uint32_t lastDequantExpertNum = params.expertPerRank;
    if (params.epilogueGranularity < params.expertPerRank) {
        lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
    }
    // Release GMM2 for the tail experts once their dequant input is ready.
    if (lastDequantExpertNum < params.expertPerRank) {
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
    }
    // Tail dequant: the rows not covered by the chunked epilogue.
    if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
        uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;
        MatrixCoord offsetC{rowStartThisCore, 0};
        uint32_t dequantLen = dequantSum;
        if (prevGroupSum1 >= params.maxOutputSize) {
            dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
        }
        MatrixCoord shapeC{dequantLen, params.problemShape.n()};
        LayoutC layoutC{dequantLen, params.problemShape.n()};
        int64_t gmOffsetC = layoutC.GetOffset(offsetC);
        int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
        blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
    }
    blockEpilogue.Finalize();
}
|
||||
|
||||
CATLASS_DEVICE
// AIV-side combine stage: scatters the GMM2 outputs back to the originating
// ranks' peer memory (dequantizing on the fly for the int8 path), then runs
// the token-unpermute kernel to produce the final, probability-weighted
// output in the original token order.
//
// Fix vs. previous revision: removed the unused local `k2`.
void Combine(Params const &params) {
    // Restore the row base cached by Dispatch for this core's peer rank.
    int32_t prevSumBeforeRank = 0;
    if (coreIdx < params.EP) {
        prevSumBeforeRank = m_prevSumBeforeRank;
    }

    int prevSum = prevSumBeforeRank;
    // Output width of GMM2 is the model hidden size.
    uint32_t n2 = params.problemShape.k();

    // TODO: compute the cumsum of tokenPerExpert here.
    typename BlockEpilogue2::Params epilogueParams{
        static_cast<int32_t>(params.EP),
        static_cast<int32_t>(params.expertPerRank),
        reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace)
    };
    BlockEpilogue2 blockEpilogue(resource, epilogueParams);
    int32_t prevGroupSum2 = 0;
    for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Wait until GMM2 has finished this group's tiles (flag group 3+).
        AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3);
        AscendC::SyncAll<true>();

        // Core i writes this group's result rows back into rank i's peer memory.
        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
            __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx);
            AscendC::GlobalTensor<ElementD2> gmRemotePeer;
            gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr));
            uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2;
            if (srcRowOffset < params.maxOutputSize) {
                uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
                // Truncate the block that straddles the capacity limit.
                if (srcRowOffset + dataRows > params.maxOutputSize) {
                    dataRows = params.maxOutputSize - srcRowOffset;
                }
                uint32_t dstRowOffset = prevSum;
                prevSum += dataRows;
                MatrixCoord offsetC{srcRowOffset, 0};
                MatrixCoord offsetPeer{dstRowOffset, 0};
                MatrixCoord shapeC{dataRows, n2};
                int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
                int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
                // Quantized path carries per-token scales into the epilogue.
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
                } else {
                    blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]);
                }
            }
        }
        prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
    }
    blockEpilogue.Finalize();
    AscendC::SyncAll<true>();
    // All ranks must have received their rows before unpermuting.
    shmem.CrossRankSync();
    MoeTokenUnpermuteTilingData tilingData;
    MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum);
    KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;

    // Gather the permuted rows back to original token order, weighted by probs.
    kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
    kernelMoeTokenUnpermuteOp.Process();
}
|
||||
|
||||
private:
|
||||
// Layout of the device workspace: sub-buffer base pointers carved sequentially
// out of params.ptrWorkspace. Region order must match the offsets assumed by
// the kernel stages (e.g. Dispatch passes ptrWorkspace + expandedRowIdx size
// as the routing scratch area).
struct WorkspaceInfo {
    GM_ADDR ptrA;               // dispatched token buffer (aliases ptrPermutedToken)
    GM_ADDR ptrPerTokenScale;   // per-token scales for GMM1 input
    GM_ADDR ptrcumsumMM;        // cumsum table consumed by GMM1/GMM2
    GM_ADDR ptrC;               // GMM1 accumulator output
    GM_ADDR ptrC2;              // GMM2 accumulator output (aliases ptrC)
    GM_ADDR ptrPermutedToken;   // GMM1-epilogue output / GMM2 input
    GM_ADDR ptrPerTokenScale2;  // per-token scales for GMM2 input
    GM_ADDR expandedRowIdx;     // routing row-index map (start of workspace)
    GM_ADDR ptrTokenPerExpert;  // token-per-expert scratch copy

    CATLASS_DEVICE
    WorkspaceInfo(){}

    // Carve the workspace into regions; each `workspaceOffset +=` reserves the
    // previous region's size before placing the next pointer.
    CATLASS_DEVICE
    WorkspaceInfo(const Params & params) {
        // GMM2 dims: k2 = intermediate size after gating split, n2 = hidden size.
        uint32_t k2 = params.problemShape.n() / 2;
        uint32_t n2 = params.problemShape.k();
        int64_t workspaceOffset = 0;
        expandedRowIdx = params.ptrWorkspace;

        workspaceOffset += AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);
        ptrcumsumMM = params.ptrWorkspace + workspaceOffset;

        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);

        // NOTE(review): a second EP*EP*expertPerRank region is reserved here
        // with no pointer assigned to it — presumably intentional padding or a
        // leftover; confirm against the host-side workspace-size computation.
        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
        ptrPerTokenScale = params.ptrWorkspace + workspaceOffset;

        workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale);
        ptrPerTokenScale2 = params.ptrWorkspace + workspaceOffset;

        workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale);
        ptrTokenPerExpert = params.ptrWorkspace + workspaceOffset;

        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
        ptrC = params.ptrWorkspace + workspaceOffset;
        // GMM1 and GMM2 accumulators share one region sized for the larger.
        ptrC2 = ptrC;

        workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC),
                               params.maxOutputSize * n2 * sizeof(ElementC));

        ptrA = params.ptrWorkspace + workspaceOffset;
        // Dispatched tokens and GMM2 input share one region sized for the larger.
        ptrPermutedToken = ptrA;
        workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA),
                               params.maxOutputSize * k2 * sizeof(ElementA));
    }
};
|
||||
|
||||
// Layout of each rank's shared (peer-accessible) memory segment: byte offsets
// of the regions every rank reads from / writes to remotely.
struct PeermemInfo {
    int64_t offsetA;                   // dispatched token staging area
    int64_t offsetPeerPerTokenScale;   // per-token scale staging area
    int64_t offsetPeerTokenPerExpert;  // all-gathered token-per-expert matrix
    int64_t offsetD;                   // combine-stage result staging area

    CATLASS_DEVICE
    PeermemInfo(){}

    CATLASS_DEVICE
    PeermemInfo(const Params & params, const HcclShmem & shmem) {
        offsetA = 0; // occupies 1/3 of the buffer size
        offsetPeerPerTokenScale = offsetA + AlignUp(shmem.SegmentSize() / 3, 512); // occupies 1 MB
        offsetD = offsetPeerPerTokenScale + MB_SIZE; // occupies the remainder
        // NOTE(review): the D region implicitly ends where this starts; there
        // is no explicit check that it does not overlap the last 2 MB — confirm
        // the host sizes the segment accordingly.
        offsetPeerTokenPerExpert = shmem.SegmentSize() - 2 * MB_SIZE; // occupies the last 2 MB
    }
};
|
||||
|
||||
Arch::Resource<ArchTag> resource;
|
||||
|
||||
uint32_t coreIdx;
|
||||
uint32_t coreNum;
|
||||
|
||||
Params params;
|
||||
WorkspaceInfo workspaceInfo;
|
||||
PeermemInfo peermemInfo;
|
||||
|
||||
int64_t m_prevSumBeforeRank;
|
||||
|
||||
AscendC::GlobalTensor<ElementA> gmA;
|
||||
AscendC::GlobalTensor<ElementC> gmC;
|
||||
AscendC::GlobalTensor<ElementScale> gmS;
|
||||
|
||||
AscendC::GlobalTensor<ElementD1> gmPermutedToken;
|
||||
AscendC::GlobalTensor<ElementScale> gmS2;
|
||||
AscendC::GlobalTensor<ElementC> gmC2;
|
||||
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale1;
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale2;
|
||||
|
||||
AscendC::GlobalTensor<int32_t> tokenPerExpert;
|
||||
AscendC::GlobalTensor<int32_t> cumsumMM;
|
||||
Layout3D tokenPerExpertLayout;
|
||||
HcclShmem shmem;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Kernel
|
||||
|
||||
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP
|
||||
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file dispatch_ffn_combine_tiling.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h"
|
||||
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
|
||||
|
||||
#ifndef ASCENDC_DISPATCH_FFN_COMBINE_TILING_H
|
||||
#define ASCENDC_DISPATCH_FFN_COMBINE_TILING_H
|
||||
// Problem description shared between host tiling and the device kernel.
// Field order is part of the tiling-data ABI; do not reorder.
struct DispatchFFNCombineInfo {
    uint32_t M;              // number of input tokens
    uint32_t K;              // hidden size (GMM1 reduction dim)
    uint32_t N;              // GMM1 output width (intermediate size, gated)
    uint32_t expertPerRank;  // experts hosted on each rank
    uint32_t maxOutputSize;  // row capacity of the dispatched-token buffers
    uint32_t isTransposeB;   // nonzero if weights are stored transposed
    uint32_t isWeightNz;     // nonzero if weights use the NZ fractal format
    uint32_t aivNum;         // number of vector cores available
    uint32_t totalUbSize;    // unified-buffer size in bytes
    uint32_t topK;           // experts selected per token
    uint32_t worldSize;      // EP world size (number of ranks)
};
|
||||
|
||||
// Communication/computation co-scheduling knobs. A value of -1 means
// "unset — let the host-side tiling choose a default".
struct CoCTiling {
    int32_t m0 = -1;             // L1 tile M
    int32_t k0 = -1;             // L1 tile K
    int32_t n0 = -1;             // L1 tile N
    int32_t swizzleDirect = -1;  // block-swizzle direction
    int32_t swizzleOffset = -1;  // block-swizzle offset
    int32_t ubMoveNum = -1;      // elements moved per UB pipeline round
    int32_t pValue = -1;         // compute/communication interleave factor
    int32_t commNpuSplit = -1;   // cores dedicated to communication
    int32_t commDataSplit = -1;  // data-split factor for communication
    int32_t lenPerLoop = -1;     // data length handled per loop
    // NOTE(review): unlike the fields above, this has no default initializer —
    // confirm the host always sets it before launch.
    uint64_t initRoutingQuantTilingKey;
    // Tiling of the embedded init-routing-quant sub-kernel.
    optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
};
|
||||
|
||||
// Top-level tiling blob passed from host to the dispatch_ffn_combine kernel.
// Field order is part of the tiling-data ABI; do not reorder.
struct DispatchFFNCombineTilingData {
    Mc2InitTiling mc2InitTiling;                   // MC2 communication init tiling
    Mc2CcTiling mc2CcTiling;                       // MC2 collective-comm tiling
    DispatchFFNCombineInfo dispatchFFNCombineInfo; // problem description
    CoCTiling cocTiling;                           // co-scheduling knobs
};
|
||||
#endif
|
||||
@@ -0,0 +1,134 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_init_routing_quant_v2.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "moe_v2_sort_one_core.h"
|
||||
#include "moe_v2_sort_multi_core.h"
|
||||
#include "moe_v2_mrgsort_out.h"
|
||||
#include "moe_v2_mrgsort.h"
|
||||
#include "moe_v2_expert_token_out.h"
|
||||
#include "moe_v2_src_to_dst_op.h"
|
||||
#include "moe_v2_src_to_dst_with_capacity.h"
|
||||
#include "moe_v2_fullload_quant.h"
|
||||
#include "moe_v2_fullload_dynamic_quant.h"
|
||||
#include "moe_v2_gather_quant.h"
|
||||
#include "moe_v2_gather_dynamic_quant.h"
|
||||
#include "moe_v2_src_to_dst_and_gather.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace MoeInitRoutingQuantV2;
|
||||
using namespace optiling;
|
||||
|
||||
template <class DTYPE_X = bfloat16_t>
// Device entry point for quantized MOE routing v2. Dispatches on `tilingKey`
// to the appropriate sort / expert-token-count / scatter-gather sub-kernels.
// Key structure (see host tiling): 2xxxx = full-load fast paths; 1QDSx with
// Q = quant mode (0 static, 1 dynamic), D = drop/capacity mode, S = multi-core
// sort. Runs on vector cores only; each phase owns its own TPipe.
__aicore__ inline void moe_init_routing_quant_v2(
    GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
    GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, GM_ADDR dynamicQuantScale, GM_ADDR workspace,
    const MoeInitRoutingQuantV2TilingData* tilingData, uint64_t tilingKey) {

    // Cube cores have no work here.
    if (g_coreType == AIC) {
        return;
    }

    if (workspace == nullptr) {
        return;
    }

    if (tilingKey == 20000) { // quant full load
        TPipe sortPipe;
        MoeV2FullLoadQuant<DTYPE_X> op;
        op.Init(x, expertIdx, scale, offset, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
        return;
    }


    else if (tilingKey == 21000) { // dynamic quant full load
        TPipe sortPipe;
        MoeV2FullLoadDynamicQuant<DTYPE_X> op;
        op.Init(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, scale, dynamicQuantScale, workspace, tilingData,
                &sortPipe);
        op.Process();
        sortPipe.Destroy();
        return;
    }

    // sort
    // Phase 1: sort tokens by expert id (single- or multi-core variant).
    if (tilingKey == 10000 || tilingKey == 10100 || tilingKey == 11000 || tilingKey == 11100) {
        TPipe sortPipe;
        MoeV2SortOneCore op;
        op.Init<MoeInitRoutingQuantV2TilingData>(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace,
                                                 tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
    } else if (tilingKey == 10010 || tilingKey == 10110 || tilingKey == 11010 || tilingKey== 11110) {
        TPipe sortPipe;
        MoeV2SortMultiCore op;
        op.Init<MoeInitRoutingQuantV2TilingData>(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace,
                                                 tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
    }

    // Phase 2: expert token counts + source-to-destination row mapping.
    if (tilingKey == 10000 || tilingKey == 10010 || tilingKey ==11000 || tilingKey ==11010) { // no-drop case
        if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) {
            TPipe expertTokenOutPipe;
            MoeV2ExpertTokenOut expertTokenOutOp;
            expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
                                                                   expandedRowIdx, workspace, tilingData, &expertTokenOutPipe);
            expertTokenOutOp.Process();
            expertTokenOutPipe.Destroy();
        }
        TPipe srcToDstPipe;
        MoeV2SrcToDstOp srcToDstOp;
        srcToDstOp.Init<MoeInitRoutingQuantV2TilingData>(expandedRowIdx, workspace, tilingData, &srcToDstPipe);
        srcToDstOp.Process();
        srcToDstPipe.Destroy();
    } else if (tilingKey ==10100 || tilingKey ==10110 || tilingKey ==11100 || tilingKey ==11110) { // drop (capacity-limited) case
        TPipe expertTokenOutPipe;
        MoeV2ExpertTokenOut expertTokenOutOp;
        expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
                                                               expandedRowIdx, workspace, tilingData, &expertTokenOutPipe);
        expertTokenOutOp.Process();
        expertTokenOutPipe.Destroy();

        if (tilingKey == 10100 || tilingKey == 10110) {
            TPipe srcToDstPipe;
            MoeV2SrcToDstWithCapacity<int8_t, MoeInitRoutingQuantV2TilingData> srcToDstWithCapacityOp;
            srcToDstWithCapacityOp.Init(expandedRowIdx, expandedX, workspace, tilingData, &srcToDstPipe);
            srcToDstWithCapacityOp.Process();
            srcToDstPipe.Destroy();
        } else {
            // Dynamic-quant drop path: fused scatter + gather, no separate
            // gather phase needed below.
            TPipe srcToDstGatherPipe;
            MoeV2SrcToDstAndGather<DTYPE_X, MoeInitRoutingQuantV2TilingData> srcToDstAndGatherOp;
            srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe);
            srcToDstAndGatherOp.Process();
            srcToDstGatherPipe.Destroy();
            return;
        }
    }

    // Phase 3: gather + quantize tokens into the permuted output buffer.
    if (tilingKey == 10000 || tilingKey == 10010 || tilingKey == 10100 || tilingKey == 10110) {
        TPipe gatherPipe;
        MoeV2GatherQuant<DTYPE_X> gatherQuantOp;
        gatherQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, workspace, tilingData, &gatherPipe);
        gatherQuantOp.Process();
        gatherPipe.Destroy();
    } else if (tilingKey == 11000 || tilingKey == 11010) {
        TPipe gatherPipe;
        MoeV2GatherDynamicQuant<DTYPE_X> gatherDynamicQuantOp;
        gatherDynamicQuantOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &gatherPipe);
        gatherDynamicQuantOp.Process();
        gatherPipe.Destroy();
    }
}
|
||||
@@ -0,0 +1,429 @@
|
||||
#pragma once
|
||||
#include "moe_init_routing_v2_tiling.h"
|
||||
|
||||
namespace optiling {

// ---- Attribute index / tiling-key layout (quant variant) ----
const static int64_t ATTR_QUANT_MODE = 6;
const static int64_t TILING_KEY_BASE = 10000;          // normal (non-full-load) path
const static int64_t TILING_KEY_PERF_BASE = 20000;     // full-load "perf" path
const static int64_t TILING_KEY_QUANT_BASE = 1000;     // + quantMode * 1000
const static int64_t TILING_KEY_DROP_MODE_BASE = 100;  // + dropPadMode * 100
const static int64_t TILING_KEY_SORT_BASE = 10;        // + 10 when sort needs multiple cores
// ---- Column-loop limits and UB buffer multipliers ----
const static int64_t FOUR_BLOCK_BYTE = 128;
const static int64_t MAX_COLS_ONE_LOOP_QUANT = 8192;   // max H per loop, static quant
const static int64_t INDEX_SCALE = 2;                  // op input index of `scale`
const static int64_t INDEX_OFFSET = 3;                 // op input index of `offset`
// Smooth-scale layout for dynamic quant: absent, [1, H], or [E, H].
const static int64_t SMOOTH_NONE = 0;
const static int64_t SMOOTH_1H = 1;
const static int64_t SMOOTH_EH = 2;
const static int64_t MAX_COLS_DYNAMIC_QUANT = 6144;    // max H per loop, dynamic quant
// Buffer-count multipliers used in the UB budget formulas below.
// NOTE(review): presumably these mirror the kernel's UB queue plan — confirm
// against the device-side implementation before changing.
const static int64_t DYNAMIC_QUANT_SRC_TO_DST_BUFFER = 15;
const static int64_t DYNAMIC_QUANT_COLS_BUFFER = 21;
const static int64_t DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER = 13;
const static int64_t DYNAMIC_QUANT_SCALE_SIZE_64 = 64;
const static int64_t DYNAMIC_QUANT_SCALE_SIZE_128 = 128;
const static int64_t OUTOUT_DYNAMIC_QUANT_SCALE = 4;   // op output index of the dynamic scale
const static int64_t FULLLOAD_H_LIMIT = 7168;          // H at which the one-loop fast path applies
|
||||
|
||||
|
||||
inline static int64_t AlignOneBlockByte(int64_t x) {
|
||||
return (x + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
|
||||
}
|
||||
|
||||
// Round x DOWN to a multiple of ONE_BLOCK_BYTE.
// NOTE(review): despite the "Ceil" suffix this floors; every caller relies on
// the floor behavior (it is used to shrink a byte budget to alignment), so
// only the name is misleading — renaming would break callers.
inline static int64_t AlignOneBlockByteCeil(int64_t x) {
    return x / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
}
|
||||
|
||||
// Tiling data handed to the quantized moe_init_routing_v2 kernel.
// NOTE(review): scalar fields have no default initializers; every field is
// expected to be filled by CopyTilingData()/GetShapeAttrsInfo() before use —
// smoothType in particular is only written on the dynamic-quant path.
struct MoeInitRoutingQuantV2TilingData {
    int64_t coreNum;                         // vector cores available
    int64_t n;                               // number of tokens (rows of x)
    int64_t cols;                            // hidden size H
    int64_t k;                               // top-k experts per token
    int64_t expertCapacity;
    int64_t expertNum;
    int64_t dropPadMode;                     // 0 = dropless, 1 = drop/pad
    int64_t expertTokensCountOrCumsumFlag;
    int64_t expertTokensBeforeCapacityFlag;
    int64_t smoothType;                      // SMOOTH_NONE / SMOOTH_1H / SMOOTH_EH
    // Per-stage tiling parameters (sort, merge, scatter, gather).
    InnerMoeV2VBSComputeTilingData vbsComputeParamsOp;
    InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp;
    InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp;
};
|
||||
|
||||
|
||||
|
||||
// Tiling computation for the quantized variant of moe_init_routing_v2.
// Extends the inner (non-quant) tiling base with static-quant and
// dynamic-quant UB budgeting and gather-stage tiling.
class MoeInitRoutingQuantV2TilingBase : public InnerMoeInitRoutingV2TilingBase {
public:
protected:

    // Captures shapes/attributes and derives smoothType for dynamic quant.
    bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override;
    uint64_t GetTilingKey() const override;
    bool GetWorkspaceSize() override;
    bool PostTiling() override;
public:
    // UB full-load feasibility checks; `space` is the byte budget already
    // consumed by the sort stage.
    bool IsFullLoadQuant(int64_t space);
    bool IsFullLoadDynamicQuant(int64_t space);
    bool IsFullLoad() override;
    // Helpers that fill an InnerMoeV2GatherOutComputeTilingData either as a
    // single full pass or split by columns/rows.
    void SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows, int64_t lastCoreRows,
        int64_t cols);
    void SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t baseMaxCols, int64_t cols);
    void SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows,
        int64_t lastCoreRows, int64_t basePerLoopMaxRows);
    void Tiling4GatherQuant();
    void Tiling4GatherDynamicQuant();
    void Tiling4SrcToDstCapacityCompute() override;
    void Tiling4GatherOutCompute() override;
    void CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst, InnerMoeV2GatherOutComputeTilingData& src);
    // Flattens the base-class tiling results into quantTilingData.
    void CopyTilingData();


    int64_t quantMode;  // 0 = static quant, otherwise dynamic quant
    MoeInitRoutingQuantV2TilingData quantTilingData;
};
|
||||
|
||||
|
||||
bool MoeInitRoutingQuantV2TilingBase::IsFullLoadQuant(int64_t space) {
|
||||
int64_t perCoreXRows = moeInitRoutingTilingData.n / aivNum;
|
||||
int64_t remainder = moeInitRoutingTilingData.n % aivNum;
|
||||
// NUM_TWO is Max xRows need add 2 becauseof the left and right row may be another row.
|
||||
perCoreXRows = remainder <= 1 ? perCoreXRows + 1 : perCoreXRows + NUM_TWO;
|
||||
int64_t quantBaseSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols);
|
||||
int64_t quantSpace =
|
||||
quantBaseSpace * (inuptXDtypeSize_ + sizeof(int8_t) + sizeof(float) + sizeof(int16_t)) * perCoreXRows;
|
||||
int64_t remainUbAfterSort = aicoreParams_.ubSize - space - quantSpace;
|
||||
return remainUbAfterSort > 0;
|
||||
}
|
||||
|
||||
bool MoeInitRoutingQuantV2TilingBase::IsFullLoadDynamicQuant(int64_t space) {
|
||||
int64_t quantSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols) * DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER;
|
||||
int64_t scaleOutSpace = 64;
|
||||
int64_t remainUbAfterSort = aicoreParams_.ubSize - space - scaleOutSpace - quantSpace;
|
||||
return remainUbAfterSort > 0;
|
||||
}
|
||||
|
||||
bool MoeInitRoutingQuantV2TilingBase::IsFullLoad() {
|
||||
if (totalLength > sortLoopMaxElement || moeInitRoutingTilingData.cols > MAX_COLS_ONE_LOOP_QUANT ||
|
||||
this->dropPadMode == 1) {
|
||||
return false;
|
||||
}
|
||||
int64_t sortSpace = AlignOneBlockByte(this->totalLength) * sizeof(int32_t) * ONE_CORE_SORT_BUFFER;
|
||||
int64_t otherSpace = AlignOneBlockByte(this->totalLength) * sizeof(int32_t) * NUM_THREE;
|
||||
int64_t expertSpace = AlignOneBlockByte(this->expertNum * sizeof(int32_t));
|
||||
if (quantMode == 0) {
|
||||
return IsFullLoadQuant(sortSpace + otherSpace + expertSpace);
|
||||
} else {
|
||||
return IsFullLoadDynamicQuant(sortSpace + otherSpace + expertSpace);
|
||||
}
|
||||
}
|
||||
|
||||
bool MoeInitRoutingQuantV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
|
||||
int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
|
||||
bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) {
|
||||
|
||||
InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode,
|
||||
expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0);
|
||||
this -> quantMode = quantMode;
|
||||
if (quantMode == 0) {
|
||||
} else {
|
||||
if (scaleDim0 > 0) {
|
||||
quantTilingData.smoothType = ((scaleDim0 == 1) ? SMOOTH_1H : SMOOTH_EH);
|
||||
} else {
|
||||
quantTilingData.smoothType = SMOOTH_NONE;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
uint64_t MoeInitRoutingQuantV2TilingBase::GetTilingKey() const {
|
||||
if (isFullLoad) {
|
||||
return TILING_KEY_PERF_BASE + quantMode * TILING_KEY_QUANT_BASE;
|
||||
}
|
||||
return TILING_KEY_BASE + quantMode * TILING_KEY_QUANT_BASE + dropPadMode * TILING_KEY_DROP_MODE_BASE +
|
||||
(totalLength > sortLoopMaxElement) * TILING_KEY_SORT_BASE;
|
||||
}
|
||||
|
||||
|
||||
// Finalize tiling: flatten the base-class results into the quant tiling
// struct that is serialized for the kernel.
bool MoeInitRoutingQuantV2TilingBase::PostTiling() {
    CopyTilingData();
    return true;
}
|
||||
// Copies one gather-stage tiling record into another.
// Fix: the original copied all 13 fields one by one (each wrapped in
// redundant parentheses). InnerMoeV2GatherOutComputeTilingData is a plain
// aggregate of int64_t members, so the implicitly-defined copy assignment
// performs exactly the same member-wise copy.
void MoeInitRoutingQuantV2TilingBase::CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst,
                                                          InnerMoeV2GatherOutComputeTilingData& src) {
    dst = src;
}
|
||||
|
||||
void MoeInitRoutingQuantV2TilingBase::CopyTilingData() {
|
||||
quantTilingData.coreNum = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.coreNum);
|
||||
quantTilingData.n = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.n);
|
||||
quantTilingData.cols = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols);
|
||||
quantTilingData.k = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.k);
|
||||
quantTilingData.expertCapacity = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertCapacity);
|
||||
quantTilingData.expertNum = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertNum);
|
||||
quantTilingData.dropPadMode = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.dropPadMode);
|
||||
quantTilingData.expertTokensCountOrCumsumFlag = (
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertTokensCountOrCumsumFlag);
|
||||
quantTilingData.expertTokensBeforeCapacityFlag = (
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertTokensBeforeCapacityFlag);
|
||||
|
||||
auto vbsTilingData = &InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.vbsComputeParamsOp;
|
||||
quantTilingData.vbsComputeParamsOp.needCoreNum = (vbsTilingData->needCoreNum);
|
||||
quantTilingData.vbsComputeParamsOp.perCoreElements = (vbsTilingData->perCoreElements);
|
||||
quantTilingData.vbsComputeParamsOp.perCoreLoops = (vbsTilingData->perCoreLoops);
|
||||
quantTilingData.vbsComputeParamsOp.perCorePerLoopElements = (vbsTilingData->perCorePerLoopElements);
|
||||
quantTilingData.vbsComputeParamsOp.perCoreLastLoopElements = (vbsTilingData->perCoreLastLoopElements);
|
||||
quantTilingData.vbsComputeParamsOp.lastCoreElements = (vbsTilingData->lastCoreElements);
|
||||
quantTilingData.vbsComputeParamsOp.lastCoreLoops = (vbsTilingData->lastCoreLoops);
|
||||
quantTilingData.vbsComputeParamsOp.lastCorePerLoopElements = (vbsTilingData->lastCorePerLoopElements);
|
||||
quantTilingData.vbsComputeParamsOp.lastCoreLastLoopElements = (vbsTilingData->lastCoreLastLoopElements);
|
||||
quantTilingData.vbsComputeParamsOp.oneLoopMaxElements = (vbsTilingData->oneLoopMaxElements);
|
||||
|
||||
quantTilingData.vmsMiddleComputeParamsOp.needCoreNum = (
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.vmsMiddleComputeParamsOp.needCoreNum);
|
||||
quantTilingData.sortOutComputeParamsOp.oneLoopMaxElements = (
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.sortOutComputeParamsOp.oneLoopMaxElements);
|
||||
|
||||
CopyGatherOutTiling(quantTilingData.srcToDstComputeParamsOp,
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstComputeParamsOp);
|
||||
CopyGatherOutTiling(quantTilingData.srcToDstCapacityComputeParamsOp,
|
||||
InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp);
|
||||
}
|
||||
|
||||
|
||||
// Extends the base workspace with a per-core float column buffer when the
// dynamic-quant gather/scatter stage has to loop over columns (each column
// chunk needs its running max in GM between loops).
// NOTE(review): reads quantTilingData.gatherOutComputeParamsOp.colLoops —
// assumes DoOpTiling (which fills it) ran before this; confirm the base
// class invokes GetWorkspaceSize after DoOpTiling.
bool MoeInitRoutingQuantV2TilingBase::GetWorkspaceSize() {
    InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize();
    bool useCols =
        (dropPadMode == 0 && quantTilingData.gatherOutComputeParamsOp.colLoops > 1) ||
        (dropPadMode == 1 &&
         InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.colLoops > 1);
    if (quantMode == 1 && useCols) {
        workspaceSize_ += aivNum * InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols * sizeof(float);
    }
    return true;
}
|
||||
|
||||
// Single-pass tiling: each core handles all its rows and the full row width
// in exactly one loop.
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData,
                                                          int64_t perCoreRows, int64_t lastCoreRows, int64_t cols) {
    auto& td = *tilingData;
    td.perCoreLoops = 1;
    td.lastCoreLoops = 1;
    td.colLoops = 1;
    td.perCorePerLoopRows = perCoreRows;
    td.perCoreLastLoopRows = perCoreRows;
    td.lastCorePerLoopRows = lastCoreRows;
    td.lastCoreLastLoopRows = lastCoreRows;
    td.perLoopCols = cols;
    td.lastLoopCols = cols;
}
|
||||
|
||||
// Splits the row width into colLoops chunks of at most baseMaxCols columns.
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData,
                                                              int64_t baseMaxCols, int64_t cols) {
    auto& td = *tilingData;
    if (baseMaxCols == 0) {
        td.colLoops = 0;
    } else {
        td.colLoops = (cols + baseMaxCols - 1) / baseMaxCols;
    }
    td.perLoopCols = std::min(baseMaxCols, cols);
    td.lastLoopCols = GetPerOrLastValue(cols, baseMaxCols);
}
|
||||
|
||||
// Splits each core's rows into loops of at most basePerLoopMaxRows rows,
// guarding against a zero per-loop capacity.
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData,
                                                              int64_t perCoreRows, int64_t lastCoreRows,
                                                              int64_t basePerLoopMaxRows) {
    auto& td = *tilingData;
    td.perCorePerLoopRows = std::min(perCoreRows, basePerLoopMaxRows);
    td.perCoreLastLoopRows = GetPerOrLastValue(perCoreRows, basePerLoopMaxRows);
    td.perCoreLoops =
        basePerLoopMaxRows == 0 ? 0 : (perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows;
    td.lastCorePerLoopRows = std::min(lastCoreRows, basePerLoopMaxRows);
    td.lastCoreLastLoopRows = GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows);
    td.lastCoreLoops =
        basePerLoopMaxRows == 0 ? 0 : (lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows;
}
|
||||
|
||||
// Tiling for the scatter-with-capacity stage (drop/pad + dynamic quant only).
// Rows (= n * k routing entries) are split evenly across cores; if a core's
// slice plus the column buffers does not fit in UB, the stage additionally
// loops over rows and/or columns.
void MoeInitRoutingQuantV2TilingBase::Tiling4SrcToDstCapacityCompute() {
    // Static quant or dropless routing uses the base-class tiling unchanged.
    if (quantMode == 0 || dropPadMode == 0) {
        InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute();
        return;
    }

    auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        // No routing entries — disable the stage.
        tilingData->needCoreNum = 0;
        return;
    }

    tilingData->needCoreNum = CeilDiv(totalLength, perCoreRows);
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    // The last core takes whatever remains after the evenly-sized cores.
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;

    // UB budget: 4 int32 index buffers per row, 15x int8 column buffers
    // (kernel queue plan), plus the fixed scale output block.
    int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR;
    int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
    int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64;
    if (rowSize + colSize + scaleSize < static_cast<int64_t>(aicoreParams_.ubSize)) {
        // Everything fits: one loop per core over the full width.
        SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols);
    } else {
        // Does not fit: cap columns per loop, then derive how many rows fit
        // per loop in the remaining UB.
        int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT;
        int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows =
            AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        if (cols < MAX_COLS_DYNAMIC_QUANT) {
            // Narrow rows: full width per loop; recompute the row budget.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Few rows: widen the column chunk to what the row budget leaves.
            baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
|
||||
|
||||
|
||||
// Tiling for the static-quant gather stage: rows are the expanded routing
// entries; each row is read in the input dtype and written as int8 with a
// per-tensor scale/offset applied.
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherQuant() {
    auto tilingData = &quantTilingData.gatherOutComputeParamsOp;
    tilingData->activateRows = totalLength;
    if (dropPadMode == 0 && activateNum > 0) {
        // Dropless with an active-row cap: only the first activateNum rows
        // are produced.
        tilingData->activateRows = (std::min(activateNum, totalLength));
    }
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);

    if (perCoreRows <= 0) {
        // No rows to gather — disable the stage.
        tilingData->needCoreNum = 0;
        return;
    }

    tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows));
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    // The last core takes the remainder.
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;
    // Bytes per column element: double-buffered int8 out + float scale +
    // int16 scratch + double-buffered input dtype.
    int64_t sizeOfCol = sizeof(int8_t) * NUM_TWO + sizeof(float) + sizeof(int16_t) + inuptXDtypeSize_ * NUM_TWO;
    int64_t rowSize = AlignOneBlockByte((perCoreRows * sizeof(int32_t) * NUM_TWO));
    int64_t colSize = AlignOneBlockByte(cols * sizeOfCol);
    // Half of UB as the budget — presumably the other half is reserved for
    // double buffering; confirm against the kernel.
    if (rowSize + colSize < static_cast<int64_t>(aicoreParams_.ubSize) / NUM_TWO) {
        SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols);
    } else {
        // Split by columns first, then derive the per-loop row budget from
        // the remaining UB.
        int64_t baseMaxCols = MAX_COLS_ONE_LOOP_QUANT;
        int64_t baseMaxColsSize = AlignOneBlockByte(baseMaxCols * sizeOfCol);
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - baseMaxColsSize) / NUM_TWO / sizeof(int32_t));
        if (cols < MAX_COLS_ONE_LOOP_QUANT) {
            // Narrow rows: keep the full width and recompute the row budget.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize) / NUM_TWO / sizeof(int32_t));
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Few rows: widen the column chunk instead.
            baseMaxCols = AlignOneBlockByteCeil((ubSize - rowSize) / sizeOfCol);
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
|
||||
|
||||
|
||||
|
||||
// Fills a gather tiling record with explicit per-core loop counts while
// keeping the full row width in a single column loop. Defaults describe the
// trivial one-loop case.
void SetGatherTilingDatawithloop(InnerMoeV2GatherOutComputeTilingData* tilingData,
                                 int64_t perCorePerLoopRows, int64_t lastCorePerLoopRows, int64_t cols,
                                 int64_t perCoreLastLoopRows = 1, int64_t lastCoreLastLoopRows = 1,
                                 int64_t perCoreLoops = 1, int64_t lastCoreLoops = 1) {
    auto& td = *tilingData;
    td.perCoreLoops = perCoreLoops;
    td.lastCoreLoops = lastCoreLoops;
    td.perCorePerLoopRows = perCorePerLoopRows;
    td.perCoreLastLoopRows = perCoreLastLoopRows;
    td.lastCorePerLoopRows = lastCorePerLoopRows;
    td.lastCoreLastLoopRows = lastCoreLastLoopRows;
    td.perLoopCols = cols;
    td.lastLoopCols = cols;
    td.colLoops = 1;
}
|
||||
|
||||
// Tiling for the dynamic-quant gather stage: each expanded row is quantized
// to int8 with a per-row scale, so UB must additionally hold the column
// buffers (21x int8 plan) and the scale output.
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherDynamicQuant() {

    auto tilingData = &quantTilingData.gatherOutComputeParamsOp;
    tilingData->activateRows = totalLength;
    if (dropPadMode == 0 && activateNum > 0) {
        // Dropless with an active-row cap.
        tilingData->activateRows = (std::min(activateNum, totalLength));
    }
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);

    if (perCoreRows <= 0) {
        // Nothing to gather.
        tilingData->needCoreNum = 0;
        return;
    }

    tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows));

    int64_t cols = InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols;

    tilingData->perCoreRows = perCoreRows;
    // Remainder rows go to the last core.
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;

    // UB budget: 4 int32 index buffers per row, 21x int8 column buffers,
    // fixed scale block, plus small alignment reserves.
    int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR;
    int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER;
    int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64;
    // Rows whose index buffers fit in the UB left over after the column
    // buffers, rounded down to whole 32-byte blocks of int32 entries.
    int64_t onceRowSize = (static_cast<int64_t>(aicoreParams_.ubSize) -
                           colSize - scaleSize -
                           ONE_BLOCK_BYTE * NUM_FOUR * NUM_THREE) /
                          (sizeof(int32_t) * NUM_FOUR);
    int64_t oneBlockNumInt = static_cast<int64_t>(ONE_BLOCK_BYTE) / static_cast<int64_t>(sizeof(int32_t));
    onceRowSize = onceRowSize / oneBlockNumInt * oneBlockNumInt;
    // Fast path: full-width single column loop, iterating over row batches.
    // Only taken without a smooth scale and at exactly H == 7168.
    bool ifOneLoop = ((static_cast<int64_t>(aicoreParams_.ubSize) > colSize +
                       scaleSize + ONE_BLOCK_BYTE * NUM_FOUR * NUM_FOUR) &&
                      quantTilingData.smoothType == SMOOTH_NONE &&
                      cols == FULLLOAD_H_LIMIT);

    int64_t perCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, perCoreRows) : perCoreRows;
    int64_t lastCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, lastCoreRows) : lastCoreRows;
    int64_t perCoreLoops = ifOneLoop ? CeilDiv(perCoreRows, perCoreOnceRowSize) : 1;
    int64_t lastCoreLoops = ifOneLoop ? CeilDiv(lastCoreRows, lastCoreOnceRowSize) : 1;
    int64_t perCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(perCoreRows, perCoreOnceRowSize) : perCoreRows;
    int64_t lastCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(lastCoreRows, lastCoreOnceRowSize) : lastCoreRows;

    if (rowSize + colSize + scaleSize < static_cast<int64_t>(aicoreParams_.ubSize) || ifOneLoop) {
        // Fits (or fast path): full width, explicit row loops.
        SetGatherTilingDatawithloop(tilingData, perCoreOnceRowSize, lastCoreOnceRowSize, cols,
                                    perCoreLastLoopRows, lastCoreLastLoopRows,
                                    perCoreLoops, lastCoreLoops);
    } else {
        // Does not fit: split columns first, then derive the row budget.
        int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT;
        int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER;
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows =
            AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        if (cols < MAX_COLS_DYNAMIC_QUANT) {
            // Narrow rows: keep full width, recompute the row budget.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Few rows: widen the column chunk instead.
            baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_COLS_BUFFER;
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
|
||||
|
||||
|
||||
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherOutCompute() {
|
||||
if (quantMode == 0) {
|
||||
Tiling4GatherQuant();
|
||||
} else {
|
||||
Tiling4GatherDynamicQuant();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,410 @@
|
||||
#pragma once
|
||||
|
||||
#include "tiling_base.h"
|
||||
|
||||
|
||||
namespace optiling {
// ---- Tiling keys for the non-quant base paths ----
const static int64_t TILING_KEY_DROPLESS_SORT_ONE_CORE = 10001;
const static int64_t TILING_KEY_DROPLESS_SORT_MULTI_CORE = 10002;
const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE = 10011;
const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE = 10012;
const static int64_t TILING_KEY_HIGH_PERFORMANCE = 20000;  // full-load path
// ---- Small numeric factors (named to avoid magic numbers) ----
const static int64_t NUM_TWO = 2;
const static int64_t NUM_THREE = 3;
const static int64_t NUM_FOUR = 4;
const static int64_t MRG_LIST_NUM = 4;          // MrgSort merges 4 lists at a time
const static int64_t SORT32_ALIGN_ELEMENT = 32; // Sort32 element granularity
const static int64_t ONE_BLOCK_BYTE = 32;       // UB block size in bytes
const static size_t DIM_ONE = 1;
const static size_t DIM_TWO = 2;
const static size_t DIM_THREE = 3;
const static int32_t SIZE_16 = 16;
const static int32_t LENGTH_1024 = 1024;
const static int64_t MAX_COLS_ONE_LOOP = 16376;
const static int64_t ASSIST_NUM = 256;
// ---- Input / attribute / output index maps ----
const static int64_t INDEX_INPUT_X = 0;
const static int64_t INDEX_INPUT_EXPERT_IDX = 1;
const static int64_t ATTR_ACTIVE_ROWS = 0;
const static int64_t ATTR_EXPERT_CAPACITY = 1;
const static int64_t ATTR_EXPERT_NUM = 2;
const static int64_t ATTR_DROP_PAD_MODE = 3;
const static int64_t ATTR_EXPERT_TOKENS_COUNT_OR_CUMSUM_FLAG = 4;
const static int64_t ATTR_EXPERT_TOKENS_BEFORE_CAPACITY_FLAG = 5;
// NOTE(review): "OUTOUT_" looks like a typo of "OUTPUT_", but the names are
// referenced elsewhere, so they are documented rather than renamed here.
const static int64_t OUTOUT_EXPANDED_X = 0;
const static int64_t OUTOUT_EXPANDED_ROW_IDX = 1;
const static int64_t OUTOUT_EXPERT_TOKENS_COUNT_OR_CUMSUM = 2;
const static int64_t OUTOUT_EXPERT_TOKENS_BEFORE_CAPACITY = 3;
const static int64_t KV_FACTOR = 2;             // key + value per sorted entry
const static int64_t ONE_CORE_SORT_BUFFER = 6;  // sort scratch multiplier
const static int64_t EXPERT_TOKENS_COUNT = 2;
|
||||
|
||||
|
||||
// Smallest r such that 4^r >= x, i.e. ceil(log4(x)); returns 0 for x <= 1.
// Fix: the original computed ceil(log(x)/log(4)) in floating point, which can
// round the quotient just above an integer at exact powers of four and
// over-count by one. Pure integer arithmetic is exact.
inline static int64_t CeilLog4(int64_t x) {
    int64_t r = 0;
    int64_t pow4 = 1;
    while (pow4 < x) {
        pow4 *= 4;  // NUM_FOUR
        ++r;
    }
    return r;
}
|
||||
|
||||
// "Last loop" size when x items are processed in chunks of y: x itself when a
// single chunk suffices, otherwise the remainder x % y. Returns 0 for y == 0.
// NOTE(review): when x > y and x is an exact multiple of y this returns 0,
// not y — the kernels appear to rely on that encoding for "last loop equals a
// full loop"; confirm before changing.
inline static int64_t GetPerOrLastValue(int64_t x, int64_t y) {
    if (y == 0) {
        return 0;
    }
    return x <= y ? x : x % y;
}
|
||||
|
||||
// Integer ceiling division: smallest integer >= dividend / divisor.
// Intended for non-negative dividends and positive divisors.
template <class T>
constexpr T CeilDiv(const T dividend, const T divisor)
{
    const T biased = dividend + divisor - 1;
    return biased / divisor;
}
|
||||
|
||||
|
||||
// Tiling for the VBS (vector bitonic/Sort32) sort stage: how the n*k routing
// entries are split across cores and per-core loops.
struct InnerMoeV2VBSComputeTilingData {
    int64_t needCoreNum = 0;
    int64_t perCoreElements = 0;
    int64_t perCoreLoops = 0;
    int64_t perCorePerLoopElements = 0;
    int64_t perCoreLastLoopElements = 0;
    int64_t lastCoreElements = 0;          // the last core takes the remainder
    int64_t lastCoreLoops = 0;
    int64_t lastCorePerLoopElements = 0;
    int64_t lastCoreLastLoopElements = 0;
    int64_t oneLoopMaxElements = 0;        // UB capacity per sort pass
};
|
||||
|
||||
// Tiling for the middle VMS (merge-sort) stage; only the core count varies.
struct InnerMoeV2VMSMiddleComputeTilingData {
    int64_t needCoreNum = 0;
};
|
||||
|
||||
// Tiling for the final sort-output stage.
struct InnerMoeV2SortOutComputeTilingData {
    int64_t oneLoopMaxElements = 0;  // elements emitted per UB pass
};
|
||||
|
||||
// Row/column loop tiling shared by the scatter (src-to-dst) and gather
// stages: rows split across cores and loops, columns split into colLoops
// chunks of perLoopCols (last chunk = lastLoopCols).
struct InnerMoeV2GatherOutComputeTilingData {
    int64_t needCoreNum = 0;
    int64_t activateRows = 0;          // rows actually produced (active-row cap)
    int64_t perCoreRows = 0;
    int64_t perCorePerLoopRows = 0;
    int64_t perCoreLastLoopRows = 0;
    int64_t lastCoreRows = 0;
    int64_t lastCorePerLoopRows = 0;
    int64_t lastCoreLastLoopRows = 0;
    int64_t perCoreLoops = 0;
    int64_t lastCoreLoops = 0;
    int64_t perLoopCols = 0;
    int64_t lastLoopCols = 0;
    int64_t colLoops = 0;
};
|
||||
|
||||
// Aggregate tiling data for the non-quant moe_init_routing_v2 pipeline.
// NOTE(review): unlike the per-stage structs above, the scalar fields here
// have no default initializers and must all be filled during tiling.
struct InnerMoeInitRoutingV2TilingData {
    int64_t coreNum;                        // vector cores available
    int64_t n;                              // number of tokens
    int64_t cols;                           // hidden size H
    int64_t k;                              // top-k experts per token
    int64_t expertCapacity;
    int64_t expertNum;
    int64_t dropPadMode;                    // 0 = dropless, 1 = drop/pad
    int64_t expertTokensCountOrCumsumFlag;
    int64_t expertTokensBeforeCapacityFlag;
    InnerMoeV2VBSComputeTilingData vbsComputeParamsOp;
    InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp;
    InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp;
};
|
||||
|
||||
|
||||
// Shared tiling computation for moe_init_routing_v2: splits the sort, merge,
// scatter and gather stages across vector cores and UB loops. Quant variants
// derive from this and override the gather/full-load hooks.
class InnerMoeInitRoutingV2TilingBase : public TilingBaseClass {

protected:
    bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) override;
    bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override;

    // Runs all per-stage tiling computations in dependency order.
    bool DoOpTiling() override;
    uint64_t GetTilingKey() const override;
    bool GetWorkspaceSize() override;


protected:
    bool CheckTokenCount(int64_t num, const char* tag);

    // Per-stage tiling hooks; the gather stage is variant-specific.
    virtual void Tiling4GatherOutCompute() = 0;
    void Tiling4SrcToDstCompute();
    virtual void Tiling4SrcToDstCapacityCompute();
    void Tiling4SortOutCompute();
    void Tiling4VMSMiddleCompute();
    void Tiling4VBSCompute();
    void ShowTilingData();
    void Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData);
    void Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData);
    // Whether the full-load "perf" path applies; decided by the derived class.
    virtual bool IsFullLoad() = 0;




    int64_t aivNum = 0;                 // vector cores on the SoC
    int64_t sortLoopMaxElement = 0;     // max elements one core sorts per pass
    int64_t mrgSortListMaxElement = 2040;
    int64_t totalLength = 0;            // n * k routing entries
    int64_t activateNum = 0;
    int64_t expertCapacity = 0;
    int64_t expertNum = 0;
    int64_t dropPadMode = 0;
    int64_t expertTokensCountOrCumsumFlag = 0;
    bool expertTokensBeforeCapacityFlag = false;
    int64_t inuptXDtypeSize_ = 0;       // byte size of x's element type
    bool isFullLoad = false;

    InnerMoeInitRoutingV2TilingData moeInitRoutingTilingData;
};
|
||||
|
||||
|
||||
// Runs every per-stage tiling computation in dependency order.
// Fix: removed the stray trailing semicolon after the function body.
bool InnerMoeInitRoutingV2TilingBase::DoOpTiling() {
    // Max elements one core can sort per UB pass: key/value pairs, quadruple
    // buffered, rounded down to the Sort32 granularity.
    sortLoopMaxElement =
        (aicoreParams_.ubSize) / (sizeof(int32_t) * NUM_TWO * NUM_FOUR) / SORT32_ALIGN_ELEMENT * SORT32_ALIGN_ELEMENT;
    isFullLoad = IsFullLoad();
    Tiling4VBSCompute();
    Tiling4VMSMiddleCompute();
    Tiling4SortOutCompute();
    Tiling4SrcToDstCompute();
    Tiling4SrcToDstCapacityCompute();
    Tiling4GatherOutCompute();
    return true;
}
|
||||
|
||||
uint64_t InnerMoeInitRoutingV2TilingBase::GetTilingKey() const {
|
||||
if (isFullLoad) {
|
||||
return TILING_KEY_HIGH_PERFORMANCE;
|
||||
}
|
||||
if (dropPadMode == 0) {
|
||||
if (totalLength <= sortLoopMaxElement) { // 排序只用到一个核排序
|
||||
return TILING_KEY_DROPLESS_SORT_ONE_CORE;
|
||||
} else {
|
||||
return TILING_KEY_DROPLESS_SORT_MULTI_CORE;
|
||||
}
|
||||
} else {
|
||||
if (totalLength <= sortLoopMaxElement) {
|
||||
return TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE;
|
||||
} else {
|
||||
return TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE;
|
||||
}
|
||||
}
|
||||
return tilingKey_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Captures the op's shapes and attributes into both the member fields and the
// serialized tiling struct, sanitizing the mutually-exclusive output flags.
// NOTE(review): the local parameters are zeroed AFTER the `this->` members
// are stored but BEFORE moeInitRoutingTilingData is filled, so in drop/pad
// mode the member keeps the caller's expertTokensCountOrCumsumFlag while the
// tiling struct gets the sanitized value (and vice versa for the before-
// capacity flag in dropless mode). Confirm this asymmetry is intended.
bool InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
    int64_t expertNum, int64_t activateNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
    bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) {

    this->activateNum = activateNum;
    this->expertCapacity = expertCapacity;
    this->expertNum = expertNum;
    this->dropPadMode = dropPadMode;
    this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
    this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
    if (dropPadMode == 1) {
        // Drop/pad mode does not output expertTokensCountOrCumsum.
        expertTokensCountOrCumsumFlag = 0;
    } else {
        // Dropless mode does not output expertTokensBeforeCapacity.
        expertTokensBeforeCapacityFlag = false;
    }
    moeInitRoutingTilingData.cols = cols;
    moeInitRoutingTilingData.n = m;
    moeInitRoutingTilingData.k = topK;
    moeInitRoutingTilingData.expertCapacity = expertCapacity;
    moeInitRoutingTilingData.expertNum = expertNum;
    moeInitRoutingTilingData.dropPadMode = dropPadMode;
    moeInitRoutingTilingData.expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
    moeInitRoutingTilingData.expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
    // Total routing entries to sort/scatter: one per (token, expert) pair.
    totalLength = moeInitRoutingTilingData.n * moeInitRoutingTilingData.k;
    inuptXDtypeSize_ = inuptXDtypeSize;
    return true;
}
|
||||
|
||||
// Caches the vector-core count and UB capacity used by all later tiling math.
bool InnerMoeInitRoutingV2TilingBase::GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) {
    aivNum = aivCoreNum;
    moeInitRoutingTilingData.coreNum = aivCoreNum;
    aicoreParams_.blockDim = aivCoreNum;
    aicoreParams_.ubSize = ubSizePlatForm;
    return true;
}
|
||||
|
||||
|
||||
bool InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize() {
|
||||
// 计算workspace大小
|
||||
size_t sortWorkspaceSize = totalLength * sizeof(float) * NUM_TWO * NUM_THREE; // 排序需要的空间
|
||||
size_t scatterWorkspaceSize = totalLength * sizeof(int32_t) * NUM_TWO;
|
||||
size_t expertTokenFlagSize = aivNum * 2 * sizeof(int32_t);
|
||||
workspaceSize_ = sortWorkspaceSize + scatterWorkspaceSize + expertTokenFlagSize + SIZE_16 * LENGTH_1024 * LENGTH_1024;
|
||||
return true;
|
||||
}
|
||||
|
||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
    // Single-core VBS tiling: the whole workload fits in one loop on one core,
    // so every per-core / last-core field collapses to the same values.
    const int64_t elems = totalLength;
    tilingData->needCoreNum = 1;
    tilingData->perCoreLoops = 1;
    tilingData->lastCoreLoops = 1;
    tilingData->perCoreElements = elems;
    tilingData->perCorePerLoopElements = elems;
    tilingData->perCoreLastLoopElements = elems;
    tilingData->lastCoreElements = elems;
    tilingData->lastCorePerLoopElements = elems;
    tilingData->lastCoreLastLoopElements = elems;
}
|
||||
|
||||
// Multi-core tiling for the VBS (sort) stage.
//
// The core count is first rounded up to a power of four (see the
// pow(4, CeilLog4(...)) below) and capped at the physical AIV count, then the
// per-core workload is aligned to the 32-element sort granularity. Finally a
// shrink loop reduces perCoreElements until the tail core's last loop gets a
// positive element count.
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
    int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement);  // round up
    needCoreNum = static_cast<int64_t>(std::pow(4, CeilLog4(needCoreNum)));
    needCoreNum = std::min(needCoreNum, aivNum);  // cannot exceed the physical core count
    if (needCoreNum > 0) {
        int64_t perCoreElements = totalLength / needCoreNum;  // elements handled per core
        // Workload aligned down / up to the 32-element sort granularity.
        int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT;
        int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements;
        int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT;
        if (lastCoreElement > alineCeilPerCoreElements) {
            // Flooring would overload the tail core; round up instead and recount cores.
            perCoreElements = alineCeilPerCoreElements;
            needCoreNum = CeilDiv(totalLength, perCoreElements);
        } else {
            perCoreElements = alineFloorPerCoreElements;
        }
        tilingData->needCoreNum = needCoreNum;
        // Shrink perCoreElements (one sort-alignment unit at a time) until the
        // tail core's last loop has a positive element count.
        do {
            tilingData->perCoreElements = perCoreElements;
            tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement);  // loops per core
            tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement);
            tilingData->perCoreLastLoopElements =
                tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements;
            tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements;
            tilingData->lastCoreLoops = tilingData->perCoreLoops;
            int64_t lastCorePerLoopElements =
                CeilDiv(CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops), SORT32_ALIGN_ELEMENT) *
                SORT32_ALIGN_ELEMENT;
            tilingData->lastCorePerLoopElements = lastCorePerLoopElements;
            tilingData->lastCoreLastLoopElements =
                tilingData->lastCoreElements - (tilingData->lastCoreLoops - 1) * tilingData->lastCorePerLoopElements;
            perCoreElements -= SORT32_ALIGN_ELEMENT;
        } while (tilingData->lastCoreLastLoopElements <= 0 && perCoreElements > 0);
    }
}
|
||||
|
||||
|
||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() {
|
||||
auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
||||
tilingData->oneLoopMaxElements = sortLoopMaxElement;
|
||||
if (totalLength <= sortLoopMaxElement) { // 只用到一个核
|
||||
Tiling4VBSOneCoreCompute(tilingData);
|
||||
return;
|
||||
}
|
||||
Tiling4VBSMultiCoreCompute(tilingData);
|
||||
}
|
||||
|
||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() {
|
||||
auto vbsComputeTilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
||||
auto tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp;
|
||||
if (vbsComputeTilingData->needCoreNum <= MRG_LIST_NUM) { // 队列数小于一次vms则没有中间归并
|
||||
tilingData->needCoreNum = 0; // 需要的核数
|
||||
} else {
|
||||
int64_t needCoreNum = CeilDiv(vbsComputeTilingData->needCoreNum, MRG_LIST_NUM);
|
||||
tilingData->needCoreNum = needCoreNum; // 需要的核数
|
||||
}
|
||||
}
|
||||
|
||||
void InnerMoeInitRoutingV2TilingBase::Tiling4SortOutCompute() {
|
||||
auto tilingData = &moeInitRoutingTilingData.sortOutComputeParamsOp;
|
||||
tilingData->oneLoopMaxElements = mrgSortListMaxElement;
|
||||
}
|
||||
|
||||
|
||||
// Tiling for the src-to-dst index scatter stage: splits totalLength row
// indices evenly across the AIV cores and then into UB-sized loops per core.
void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCompute() {
    auto tilingData = &moeInitRoutingTilingData.srcToDstComputeParamsOp;

    // Max rows per loop after reserving UB for the assist table and per-core
    // sort-alignment scratch.
    int64_t perLoopMaxRows = (aicoreParams_.ubSize - ASSIST_NUM * sizeof(float) - aivNum * SORT32_ALIGN_ELEMENT) /
                             (SORT32_ALIGN_ELEMENT * NUM_TWO) / NUM_TWO;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        // Nothing to scatter: disable the stage.
        tilingData->needCoreNum = 0;
        return;
    }

    int64_t needCoreNum = CeilDiv(totalLength, perCoreRows);
    tilingData->needCoreNum = needCoreNum;
    // Rows left over for the tail core.
    int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->perCoreRows = perCoreRows;
    if (perLoopMaxRows >= tilingData->perCoreRows) {  // everything fits in one loop
        tilingData->perCorePerLoopRows = tilingData->perCoreRows;
        tilingData->perCoreLastLoopRows = tilingData->perCoreRows;
    } else {
        tilingData->perCorePerLoopRows = perLoopMaxRows;
        // Remainder handled by the final loop.
        tilingData->perCoreLastLoopRows =
            tilingData->perCoreRows - (CeilDiv(tilingData->perCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows;
    }
    tilingData->lastCoreRows = lastCoreNum;
    // Same split, applied to the (possibly smaller) tail-core workload.
    if (perLoopMaxRows >= tilingData->lastCoreRows) {
        tilingData->lastCorePerLoopRows = tilingData->lastCoreRows;
        tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows;
    } else {
        tilingData->lastCorePerLoopRows = perLoopMaxRows;
        tilingData->lastCoreLastLoopRows =
            tilingData->lastCoreRows - (CeilDiv(tilingData->lastCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows;
    }
}
|
||||
|
||||
|
||||
// Tiling for the src-to-dst scatter in drop/pad (expert-capacity) mode: in
// addition to splitting rows across cores, the feature columns may also be
// split into loops when one row of data plus its index scratch does not fit
// in UB.
void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute() {
    auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);

    if (perCoreRows <= 0) {
        // Nothing to process: disable the stage.
        tilingData->needCoreNum = 0;
        return;
    }

    int64_t needCoreNum = CeilDiv(totalLength, perCoreRows);
    tilingData->needCoreNum = needCoreNum;
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    int64_t lastCoreRows = totalLength - perCoreRows * (needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;

    // Byte footprint of the per-core index buffers (two int32 per row) and of
    // one full row of input data, each rounded up to a whole 32-byte block.
    int64_t rowSize =
        (perCoreRows * sizeof(int32_t) * 2 + ONE_BLOCK_BYTE + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
    int64_t colSize = (cols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;

    if (rowSize + colSize < static_cast<int64_t>(aicoreParams_.ubSize)) {
        // Everything fits in UB at once: single loop over rows and columns.
        tilingData->perCorePerLoopRows = perCoreRows;
        tilingData->perCoreLastLoopRows = perCoreRows;
        tilingData->lastCorePerLoopRows = lastCoreRows;
        tilingData->lastCoreLastLoopRows = lastCoreRows;
        tilingData->perCoreLoops = 1;
        tilingData->lastCoreLoops = 1;
        tilingData->perLoopCols = cols;
        tilingData->lastLoopCols = cols;
        tilingData->colLoops = 1;
    } else {
        // UB too small: start from a base column chunk (MAX_COLS_ONE_LOOP) and
        // derive how many rows fit alongside it, then widen whichever
        // dimension has slack.
        int64_t baseMaxCols = MAX_COLS_ONE_LOOP;
        int64_t baseMaxColsSize = (baseMaxCols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        int64_t basePerLoopMaxRows = (static_cast<int64_t>(aicoreParams_.ubSize) - baseMaxColsSize - ONE_BLOCK_BYTE) /
                                     static_cast<int64_t>(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        if (cols < MAX_COLS_ONE_LOOP) {
            // Columns fit in one loop: give the reclaimed space back to rows.
            basePerLoopMaxRows = (static_cast<int64_t>(aicoreParams_.ubSize) - colSize - ONE_BLOCK_BYTE) /
                                 static_cast<int64_t>(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Rows fit in one loop: give the reclaimed space back to columns.
            baseMaxCols =
                (static_cast<int64_t>(aicoreParams_.ubSize) - rowSize) / inuptXDtypeSize_ / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        }
        tilingData->perLoopCols = (std::min(baseMaxCols, cols));
        tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols));
        tilingData->colLoops = ((cols + baseMaxCols - 1) / baseMaxCols);
        tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows));
        tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows));
        tilingData->perCoreLoops = ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
        tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows));
        tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows));
        tilingData->lastCoreLoops = ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
    }
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_common.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_COMMON_H
#define INNER_MOE_V2_COMMON_H

#include "kernel_operator.h"

namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Split-axis selectors.
constexpr int64_t SPLIT_N = 0;
constexpr int64_t SPLIT_K = 1;
// Sentinel near -FLT_MAX; used to pad sort inputs so padded slots sink to the end.
constexpr float MIN_FP32 = -3.4e38;
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;  // elements covered by one Sort32 repeat
constexpr int64_t BLOCK_BYTES = 32;          // UB block granularity in bytes
constexpr int64_t INT32_ONE_BLOCK_NUM = 8;   // int32 elements per 32-byte block

constexpr int64_t ASSIST_NUM = 256;        // length of the assist index table below
constexpr int64_t ASSIST_INDEX_NUM = 32;   // distinct index values encoded in the table

// Merge-list counts for the multi-way merge stage.
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;

constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;

constexpr int64_t MAX_EXPERT_NUM = 5120;   // cap on the per-iteration expert window held in UB
// dropPadMode values.
constexpr int64_t DROPLESS_MODE = 0;
constexpr int64_t DROP_PAD_MODE = 1;
// expertTokensCountOrCumsum output modes.
// NOTE(review): "EXERPT" is a typo for "EXPERT"; kept as-is because these
// names are referenced across the kernel sources.
constexpr int64_t EXERPT_TOKENS_COUNT = 2;
constexpr int64_t EXERPT_TOKENS_CUMSUM = 1;
constexpr int64_t EXERPT_TOKENS_NONE = 0;
constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1;

// Assist index table: assist[8*i] == i for i in [0, 32); the 7 values after
// each entry are zero padding (one 32-byte block per int32 index).
const __gm__ int32_t assist[256] = {
    0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,
    4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0,
    8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0,
    12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0,
    16, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0,
    20, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0,
    24, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0, 0,
    28, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0};
|
||||
// Ceiling division; a zero divisor yields 0 instead of trapping on device.
__aicore__ inline int64_t Ceil(int64_t a, int64_t b) {
    return (b == 0) ? 0 : (a + b - 1) / b;
}
|
||||
|
||||
// Rounds elementNum up so its byte footprint is a whole number of 32-byte
// blocks, returning the aligned ELEMENT count. Zero bytes yields 0.
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) {
    if (bytes == 0) {
        return 0;
    }
    int64_t alignedBytes = (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
    return alignedBytes / bytes;
}
|
||||
|
||||
// Byte size of elementNum elements, rounded up to a whole 32-byte block.
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) {
    int64_t rawBytes = elementNum * bytes;
    return (rawBytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Min(T a, T b) {
|
||||
return a > b ? b : a;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Max(T a, T b) {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
|
||||
// Sets a hardware-event flag and immediately waits on it, acting as an
// ordering barrier between the two pipeline engines encoded in `event`
// (e.g. V_MTE3 orders vector writes before MTE3 data moves).
// NOTE(review): `event` (template arg) and `evt` (runtime arg) are passed
// the same value at every call site in these files — confirm they must match.
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt) {
    // Event IDs are leased from the TPipe allocator rather than hard-coded.
    event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
    SetFlag<event>(eventId);
    WaitFlag<event>(eventId);
}
|
||||
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_COMMON_H
|
||||
@@ -0,0 +1,310 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_expert_token_out.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_EXPERT_TOKEN_OUT_H
|
||||
#define INNER_MOE_V2_EXPERT_TOKEN_OUT_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Each core publishes two int32 values: (last expert id, its token count).
constexpr int64_t EXPERT_ID_VALUE_NUM = 2;
|
||||
|
||||
// Walks the sorted expert-id sequence (from workspace) and emits per-expert
// token counts, either as raw counts, a cumulative sum, or the
// before-capacity counts, depending on dropPadMode and the output flags.
class MoeV2ExpertTokenOut {
public:
    __aicore__ inline MoeV2ExpertTokenOut(){};
    template <typename TilingData>
    __aicore__ inline void Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                                GM_ADDR expandedRowIdx, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn(int64_t progress);
    __aicore__ inline void Compute(int64_t progress);
    __aicore__ inline void SyncAll();
    __aicore__ inline void InitLocal();
    __aicore__ inline void GetExpertTokenCount(int32_t curExpertId);
    __aicore__ inline void CopyOutTokenGm();
    __aicore__ inline void CopyOutExpertTokensCumsum(bool isTail);
    __aicore__ inline void CopyOutExpertTokensCount(bool isTail);

private:
    TPipe* pipe;
    // UB queues: staged expert ids in, accumulated per-expert counts out.
    TQue<QuePosition::VECIN, 1> copyInQueue;
    TQue<QuePosition::VECIN, 1> expertTokenIdxCopyInQueue;
    TQue<QuePosition::VECOUT, 1> expertTokenIdxCopyOutQueue;

    // GM views of the kernel outputs and workspace regions.
    GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
    GlobalTensor<int32_t> expertTokensBeforeCapacityGm;
    GlobalTensor<int32_t> expandedExpertIdxGm;   // sorted expert ids (workspace)
    GlobalTensor<int32_t> expertIdxValueGm;      // per-core (expert id, count) pairs (workspace)
    GlobalTensor<int32_t> expandedRowIdxGm;

    // UB tensor holding the counts for the current expert window.
    LocalTensor<int32_t> expertTokenIdxOutLocal;

    const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;

    // Per-core row/loop split copied from tiling data in Init().
    int64_t coreNum;
    int64_t blockIdx;
    int64_t totalLength;
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t expertNum;
    int64_t expertNumUbAlign;   // expert window size held in UB (capped at MAX_EXPERT_NUM)
    int64_t dropPadMode = 0;
    int64_t expertTokensCountOrCumsumFlag = 0;
    int64_t expertTokensBeforeCapacityFlag = 0;

    // Running scan state across loops.
    int64_t tokenCount = 0;      // tokens seen for the current expert
    int64_t expertIdx = 0;       // slot of the current expert inside the UB window
    int32_t lastExpertId = -1;   // most recent expert id seen (-1 = none yet)
    int32_t firstExpertId = -1;  // first expert id of the current UB window

    int32_t expertTokenValue = 0;  // running cumulative token total
};
|
||||
|
||||
// Zero-fills the UB count buffer and, in drop/pad mode, pre-fills this core's
// slice of expandedRowIdx in GM with -1.
__aicore__ inline void MoeV2ExpertTokenOut::InitLocal() {
    LocalTensor<int32_t> tokenIdxLocal = expertTokenIdxCopyOutQueue.AllocTensor<int32_t>();
    Duplicate<int32_t>(tokenIdxLocal, 0, this->expertNumUbAlign);
    expertTokenIdxCopyOutQueue.EnQue<int32_t>(tokenIdxLocal);

    // expandedRowIdx initialized to -1, which is used in the
    // src_to_dst_with_capacity step; this step's SyncAll makes the
    // initialization visible to every core.
    if (this->dropPadMode == 0) {
        return;  // dropless mode: no -1 pre-fill needed
    }
    LocalTensor<int32_t> outLocal = copyInQueue.AllocTensor<int32_t>();
    int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows;
    Duplicate<int32_t>(outLocal, -1, perLoopRows);
    // Vector fill must complete before MTE3 copies the buffer out.
    SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    for (int64_t loop = 0; loop < loops; loop++) {
        int64_t copyLength = perLoopRows;
        if (loop == loops - 1) {
            copyLength = lastLoopRows;  // tail loop may be shorter
        }
        DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0,
                                     0};
        DataCopyPad(expandedRowIdxGm[this->blockIdx * this->srcToDstTilingData->perCoreRows + loop * perLoopRows], outLocal,
                    copyParams);
    }
    // Copies out must finish before the buffer is reused for copy-in.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
    copyInQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Stages the next chunk of sorted expert ids from GM workspace into UB.
__aicore__ inline void MoeV2ExpertTokenOut::CopyIn(int64_t progress) {
    const int64_t gmOffset = progress * perLoopRows;
    const int64_t copyElems = Align(currentLoopRows, sizeof(int32_t));
    LocalTensor<int32_t> stage = copyInQueue.AllocTensor<int32_t>();
    DataCopy(stage, expandedExpertIdxGm[gmOffset], copyElems);
    copyInQueue.EnQue<int32_t>(stage);
}
|
||||
|
||||
// Consumes one sorted expert id: extends the current expert's token run, or
// closes it out when the id increases. When the new id falls outside the UB
// expert window, flushes the window to GM and slides it forward.
// Relies on the ids being non-decreasing (they come from the sort stage).
__aicore__ inline void MoeV2ExpertTokenOut::GetExpertTokenCount(int32_t curExpertId) {
    this->tokenCount++;
    if (this->lastExpertId < curExpertId) {
        // Expert changed: record the finished run (tokenCount-1 tokens, since
        // the current token belongs to the new expert) and start a new run.
        this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount - 1);
        this->tokenCount = 1;
        this->expertIdx += (curExpertId - this->lastExpertId);
        // Slide the UB window until curExpertId fits inside it.
        while (curExpertId - this->firstExpertId + 1 > this->expertNumUbAlign) {
            // Scalar writes must land before MTE3 flushes the window.
            SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
            CopyOutExpertTokensCumsum(false);
            CopyOutExpertTokensCount(false);
            // Flush must finish before the vector engine re-zeros the buffer.
            SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
            Duplicate<int32_t>(this->expertTokenIdxOutLocal, 0, this->expertNumUbAlign);
            // Zeroing must finish before scalar writes resume.
            SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
            this->firstExpertId += this->expertNumUbAlign;
            this->expertIdx = curExpertId - this->firstExpertId;
        }
        this->lastExpertId = curExpertId;
    }
}
|
||||
|
||||
// Scans one staged chunk of sorted expert ids, accumulating per-expert token
// counts into the UB window. The in-progress run for the current expert is
// written to its slot after every chunk so the final chunk leaves it correct.
__aicore__ inline void MoeV2ExpertTokenOut::Compute(int64_t progress) {
    LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
    // Staged data must arrive before scalar reads begin.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        // First chunk on this core: anchor the window at the first id seen.
        this->lastExpertId = inLocal.GetValue(0);
        this->firstExpertId = this->lastExpertId;
    }
    for (int64_t i = 0; i < currentLoopRows; i++) {
        int32_t expertId = inLocal.GetValue(i);
        GetExpertTokenCount(expertId);
    }
    // Record the (possibly still open) run of the current expert.
    this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount);
    copyInQueue.FreeTensor(inLocal);
}
|
||||
|
||||
// Converts the UB window of per-expert counts to a running cumulative sum and
// atomically adds it to the cumsum output in GM. `isTail` marks the final
// flush, which must also propagate the final total to experts that received
// no tokens. Only active in dropless mode with the cumsum output requested.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCumsum(bool isTail) {
    if (this->dropPadMode != DROPLESS_MODE || expertTokensCountOrCumsumFlag != EXERPT_TOKENS_CUMSUM) {
        return;
    }
#ifdef __CCE_KT_TEST__
    // CPU twin debugging cannot use multi-core sync, so the index may hold
    // uninitialized garbage; guard against out-of-range access.
    if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) {
        return;
    }
#endif
    int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
    int64_t end = this->expertNum - this->firstExpertId;  // experts left from the window start
    // In-place prefix sum over the window.
    for (int64_t i = 0; i < copyLength; i++) {
        this->expertTokenValue += this->expertTokenIdxOutLocal.GetValue(i);
        this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue);
    }
    // If the remaining UB is sufficient, use the UB space to pad the tail;
    // otherwise copy the calculated data first and then stream the final
    // total to the remaining expert positions.
    if (isTail && end <= this->expertNumUbAlign) {
        int64_t startAlign = Min(Align(copyLength, sizeof(int32_t)), end);
        // Scalar-fill up to the next block boundary, vector-fill the rest.
        for (int64_t i = copyLength; i < startAlign; i++) {
            this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue);
        }
        if (startAlign < end) {
            Duplicate<int32_t>(this->expertTokenIdxOutLocal[startAlign], this->expertTokenValue, end - startAlign);
        }
        copyLength = end;
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    }
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
    // Atomic add lets every core accumulate into the shared cumsum output.
    SetAtomicAdd<int32_t>();
#ifndef __CCE_KT_TEST__
    DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
#endif
    SetAtomicNone();
    if (isTail && end > this->expertNumUbAlign) {
        // Tail extends past the UB window: fill UB with the final total and
        // stream it out in expertNumUbAlign-sized chunks.
        int64_t remainderLength = end - copyLength;
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
        Duplicate<int32_t>(this->expertTokenIdxOutLocal, this->expertTokenValue, this->expertNumUbAlign);
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        int64_t loopTimes = remainderLength / this->expertNumUbAlign + 1;
        for (int64_t i = 0; i < loopTimes; i++) {
            copyLength = i == loopTimes - 1 ? remainderLength - this->expertNumUbAlign * i : this->expertNumUbAlign;
            DataCopyExtParams params{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
            SetAtomicAdd<int32_t>();
            DataCopyPad(expertTokensCountOrCumsumGm[this->lastExpertId + 1 + this->expertNumUbAlign * i],
                        this->expertTokenIdxOutLocal, params);
            SetAtomicNone();
        }
    }
}
|
||||
|
||||
// Atomically adds the UB window of raw per-expert counts to the appropriate
// GM output: before-capacity counts in drop/pad mode, or the count output in
// dropless mode. `isTail` trims the copy to the experts actually seen.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCount(bool isTail) {
    int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
#ifdef __CCE_KT_TEST__
    // CPU twin debugging: skip the output copy entirely.
    return;
#endif
    // Atomic add so counts from all cores accumulate in GM.
    SetAtomicAdd<int32_t>();
    if (this->dropPadMode == DROP_PAD_MODE && expertTokensBeforeCapacityFlag > EXERPT_TOKENS_NONE) {
        DataCopyPad(expertTokensBeforeCapacityGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
    }
    if (this->dropPadMode == DROPLESS_MODE && expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
        DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
    }
    SetAtomicNone();
}
|
||||
|
||||
// Final flush after the scan finishes. Dropless mode emits the cumsum/count
// outputs; drop/pad mode additionally publishes this core's
// (last expert id, open token count) pair to the workspace so a later stage
// can stitch runs that straddle core boundaries.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutTokenGm() {
    if (this->dropPadMode == DROPLESS_MODE) {
        SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
        CopyOutExpertTokensCumsum(true);
        CopyOutExpertTokensCount(true);
        return;
    }
    // Stash the pair in the 2 extra slots reserved past the count window.
    this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign, this->lastExpertId);
    this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign + 1, this->tokenCount);
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(EXPERT_ID_VALUE_NUM * sizeof(int32_t)),
                                 0, 0, 0};
    // Scalar writes must land before the MTE3 copy-out.
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    DataCopyPad(expertIdxValueGm[this->blockIdx * EXPERT_ID_VALUE_NUM],
                this->expertTokenIdxOutLocal[this->expertNumUbAlign], copyParams);
    CopyOutExpertTokensCount(true);
}
|
||||
|
||||
// Cross-core barrier. Skipped when only one core runs, and compiled out in
// the CPU twin-debug build, where multi-core sync is unavailable.
__aicore__ inline void MoeV2ExpertTokenOut::SyncAll() {
    if (coreNum != 1) {
#ifndef __CCE_KT_TEST__
        AscendC::SyncAll();
#endif
    }
}
|
||||
|
||||
template <typename TilingData>
|
||||
__aicore__ inline void MoeV2ExpertTokenOut::Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
|
||||
GM_ADDR expandedRowIdx, GM_ADDR workspace,
|
||||
const TilingData* tilingData, TPipe* tPipe) {
|
||||
int64_t blockNum = GetBlockNum();
|
||||
this->pipe = tPipe;
|
||||
//this->blockIdx = GetBlockIdx();
|
||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
this->coreNum = tilingData->coreNum;
|
||||
this->totalLength = tilingData->n * tilingData->k;
|
||||
this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp);
|
||||
this->expertNum = tilingData->expertNum;
|
||||
this->dropPadMode = tilingData->dropPadMode;
|
||||
this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
|
||||
this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;
|
||||
|
||||
if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
|
||||
this->coreRows = this->srcToDstTilingData->lastCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
|
||||
} else {
|
||||
this->coreRows = this->srcToDstTilingData->perCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
|
||||
}
|
||||
|
||||
expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, Align(this->totalLength, sizeof(int32_t)));
|
||||
if (this->dropPadMode == DROPLESS_MODE && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
|
||||
expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, this->expertNum);
|
||||
}
|
||||
if (this->dropPadMode == DROP_PAD_MODE && this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
|
||||
expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity, this->expertNum);
|
||||
}
|
||||
|
||||
expandedExpertIdxGm.SetGlobalBuffer(
|
||||
(__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2,
|
||||
this->coreNum * 2);
|
||||
|
||||
this->expertNumUbAlign = Min(Align(this->expertNum, sizeof(int32_t)), MAX_EXPERT_NUM);
|
||||
pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES);
|
||||
pipe->InitBuffer(expertTokenIdxCopyInQueue, 1, this->expertNumUbAlign * sizeof(int32_t));
|
||||
pipe->InitBuffer(expertTokenIdxCopyOutQueue, 1, (this->expertNumUbAlign + EXPERT_ID_VALUE_NUM) * sizeof(int32_t));
|
||||
}
|
||||
|
||||
// Per-core driver: initialize buffers, scan this core's slice of the sorted
// expert ids loop by loop, flush the results, then join the cross-core
// barrier (every core reaches SyncAll, active or not).
__aicore__ inline void MoeV2ExpertTokenOut::Process() {
    if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
        int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows;
        currentLoopRows = perLoopRows;
        InitLocal();
        this->expertTokenIdxOutLocal = expertTokenIdxCopyOutQueue.DeQue<int32_t>();
        // Full-size loops first; the tail loop is handled separately below.
        for (int64_t loop = 0; loop < loops - 1; loop++) {
            CopyIn(loop);
            Compute(loop);
        }
        currentLoopRows = lastLoopRows;
        CopyIn(loops - 1);
        Compute(loops - 1);
        CopyOutTokenGm();
        expertTokenIdxCopyOutQueue.FreeTensor(this->expertTokenIdxOutLocal);
    }
    this->SyncAll();
}
|
||||
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_EXPERT_TOKEN_OUT_H
|
||||
@@ -0,0 +1,468 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_v2_fullload_dynamic_quant.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H
|
||||
#define MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H
|
||||
|
||||
#include "moe_v2_mrgsort.h"
|
||||
#include "moe_v2_sort_base.h"
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Full-load variant of the routing kernel with dynamic quantization: the
// whole expert-id list fits in UB, so sorting, index scatter, and quantized
// gather of x are all done by this one class. T is the input dtype of x.
template <typename T>
class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
public:
    __aicore__ inline MoeV2FullLoadDynamicQuant(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                GM_ADDR expertTokensCountOrCumsum, GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale,
                                GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn();
    __aicore__ inline void SortCompute();
    __aicore__ inline void CopyOutIdx();
    __aicore__ inline void CopyOutEmpty();
    __aicore__ inline void CopyOutXQuant1H();   // quantize with a single smooth vector (1 x H)
    __aicore__ inline void CopyOutXQuantEH();   // quantize with per-expert smooth vectors (E x H)
    __aicore__ inline void ComputeExpertTokenCountOrCumsum();
    __aicore__ inline void Compute(LocalTensor<float>& smoothLocal);

private:
    int64_t sortNum_;   // element count padded for the sort stage
    const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_;
    int64_t blockIdx_;
    int64_t needCoreNum_;
    int64_t coreRows_;
    int64_t perCoreRows_;
    int64_t k_;              // top-k per token
    int64_t n_;              // token count
    int64_t cols_;           // hidden size (row width of x)
    int64_t activateRows_;
    int64_t expertNum;
    int64_t expertCapacity;
    int64_t smoothType;      // selects 1xH vs ExH smooth handling
    int64_t colsAlign;

    // UB queues for the sort / gather / quantize pipeline.
    TQue<QuePosition::VECIN, 1> xCopyInQueue_;
    TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue_;
    TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue_;
    TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue_;
    TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue_;
    TQue<QuePosition::VECIN, 1> smoothInQueue;
    TQue<QuePosition::VECOUT, 1> calcQueue;
    TQue<QuePosition::VECOUT, 1> inputXOutQueue;
    TQue<QuePosition::VECOUT, 1> scaleOutQueue;

    // GM views of kernel inputs.
    GlobalTensor<T> xGm_;
    GlobalTensor<int32_t> expertIdxGm_;
    GlobalTensor<float> quantSmoothGm;
    GlobalTensor<float> dynamicQuantScaleGm;

    // GM views of kernel outputs.
    GlobalTensor<int8_t> expandedXGm_;   // quantized gathered x
    GlobalTensor<int32_t> expandedRowIdxGm_;
    GlobalTensor<int32_t> expandedExpertIdxGm_;
    GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
    GlobalTensor<int32_t> expertTokensBeforeCapacityGm;

    int64_t expertTokensCountOrCumsumFlag = 0;
    int64_t expertTokensBeforeCapacityFlag = 0;
    int64_t dropPadMode = 0;

    // Sort byproducts reused across phases.
    LocalTensor<uint32_t> expandDstToSrcRowLocal;
    LocalTensor<int32_t> expandedExpertIdxLocal;
};
|
||||
|
||||
// Loads the whole expert-id list into UB and appends an ascending row-index
// ramp [0, totalLength) right after it (at offset sortNum_), pairing each
// expert id with its original row for the sort stage.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyIn() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
    DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                     0, 0, 0};
    DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
    DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams);
    // Row indices 0,1,2,... generated on-chip rather than loaded from GM.
    ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
    sortDataCopyInQueue.EnQue(inLocal);
}
|
||||
|
||||
// Sorts all (expertId, rowId) pairs by expert id in a single full-load pass.
// Produces and enqueues:
//   - expandedExpertIdxLocal: expert ids in ascending order
//   - expandDstToSrcRowLocal: source row for each sorted position (dst -> src)
//   - expandedRowIdx: the inverse mapping (src -> dst), obtained by a second sort
// The AscendC Sort primitive sorts fp32 keys in descending order, so keys are
// negated before sorting and negated back afterwards to get ascending order.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::SortCompute() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertIdxLocal = inLocal[0];
    // Reuse the int32 buffer in place as fp32 sort keys.
    LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
    Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Pad the tail of the final sort repeat with MIN_FP32 so that the padding
    // elements sink to the end of the (descending) sorted sequence.
    int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    // First sort: key = -expertId, value = source row index.
    LocalTensor<float> concatLocal;
    LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast<uint32_t>();
    LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor<float>();
    expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor<uint32_t>();
    LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
    Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
         this->totalLength);
    pipe_barrier(PIPE_V);
    // Undo the key negation to recover ascending expert ids.
    Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> expandedExpertIdxLocalInt32;
    expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
    Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    expandedExpertIdxCopyOutQueue_.EnQue<int32_t>(expandedExpertIdxLocalInt32);

    // Second sort: key = -(dst->src row), value = a fresh 0..N-1 ramp; its
    // Extract yields the inverse (src -> dst) permutation.
    LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor<uint32_t>();
    LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
    Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
    pipe_barrier(PIPE_V);
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    expandedRowIdxCopyOutQueue_.EnQue<uint32_t>(expandedRowIdx);
    sortDataCopyInQueue.FreeTensor(inLocal);
    // NOTE: expandDstToSrcRowLocal stays allocated; it is freed later by
    // CopyOutXQuant1H()/CopyOutXQuantEH().
}
|
||||
|
||||
// Writes the src->dst row mapping to GM. The tensor is re-enqueued (not freed)
// because CopyOutXQuant1H() still needs it to scatter quantized rows.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutIdx() {
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    DataCopyParams intriParams;
    intriParams.blockCount = 1;
    intriParams.blockLen = this->totalLength * sizeof(int32_t);
    DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams);
    expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::ComputeExpertTokenCountOrCumsum() {
|
||||
expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
|
||||
LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue_.AllocTensor<int32_t>();
|
||||
|
||||
int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
|
||||
Duplicate(expertTokensCount, 0, expertNumAlign);
|
||||
SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
|
||||
|
||||
int32_t lastExpertId = expandedExpertIdxLocal.GetValue(0);
|
||||
int64_t tokenCount = 0;
|
||||
int64_t lastExpertCount = 0;
|
||||
for (int64_t i = 0; i < this->totalLength; i++) {
|
||||
int32_t curExpertId = expandedExpertIdxLocal.GetValue(i);
|
||||
tokenCount++;
|
||||
while (lastExpertId < curExpertId) {
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
|
||||
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
|
||||
tokenCount = 1;
|
||||
}
|
||||
lastExpertId++;
|
||||
}
|
||||
}
|
||||
#ifndef __CCE_KT_TEST__
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount);
|
||||
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
|
||||
lastExpertId++;
|
||||
while (lastExpertId < this->expertNum) {
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount);
|
||||
lastExpertId++;
|
||||
}
|
||||
}
|
||||
DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
|
||||
0};
|
||||
if (this->expertTokensCountOrCumsumFlag > 0) {
|
||||
DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
|
||||
}
|
||||
expertTokensCopyOutQueue_.FreeTensor(expertTokensCount);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Counterpart of ComputeExpertTokenCountOrCumsum() for cores that do not emit
// expert-token statistics: just dequeues the sorted expert ids so the member
// tensor is populated for CopyOutXQuantEH() and queue state stays balanced.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutEmpty() {
    expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
}
|
||||
|
||||
// Per-row dynamic quantization: optionally applies a smooth scale, computes
// scale = max(|x|) / 127, and emits the int8 row plus the fp32 scale.
// Consumes one row from xCopyInQueue_; produces into inputXOutQueue (int8)
// and scaleOutQueue (scale replicated into one block).
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>& smoothLocal) {
    LocalTensor<float> inLocal = xCopyInQueue_.DeQue<float>();

    LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();

    // Non-fp32 inputs were staged at offset colsAlign by the caller; widen
    // them in place to fp32 at the start of the buffer.
    if constexpr (!IsSameType<T, float>::value) {
        Cast(inLocal, inLocal.ReinterpretCast<T>()[colsAlign], RoundMode::CAST_NONE, this->cols_);
        pipe_barrier(PIPE_V);
    }

    if (smoothType != 0) {
        Mul(inLocal, inLocal, smoothLocal, this->cols_);
        pipe_barrier(PIPE_V);
    }

    Abs(tempLocal, inLocal, this->cols_);
    pipe_barrier(PIPE_V);

    ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols_);
    pipe_barrier(PIPE_V);

    // NOTE(review): if the row is all zeros, maxValue is 0 and the Div below
    // divides by zero — presumably acceptable upstream; confirm.
    float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;

    Duplicate<float>(dynamicQuantLocal, maxValue, 8);
    Duplicate<float>(tempLocal, maxValue, this->cols_);
    pipe_barrier(PIPE_V);

    Div(tempLocal, inLocal, tempLocal, this->cols_);
    pipe_barrier(PIPE_V);

    // fp32 -> half (in place), then half -> int8 with rounding.
    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols_);
    pipe_barrier(PIPE_V);

    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols_);

    calcQueue.FreeTensor(tempLocal);
    inputXOutQueue.EnQue(outLocal);
    scaleOutQueue.EnQue(dynamicQuantLocal);
}
|
||||
|
||||
// Copy-out path for smoothType 0/1 (no smooth, or a single shared [H] smooth
// vector): loads each source row once, quantizes it, and scatters the int8
// result to every destination slot (up to k) mapped by expandedRowIdx.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
    // Sort-stage leftovers are no longer needed on this path.
    expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
    expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);

    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    // This core handles expanded rows [curRowsStart, curRowsEnd], which map to
    // source token rows [startXRow, endXRow] (k expanded rows per token).
    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;
    int64_t startXRow = curRowsStart / this->k_;
    int64_t endXRow = curRowsEnd / this->k_;

    DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};

    // smoothType == 1: load the shared smooth vector once, outside the loop.
    LocalTensor<float> smoothLocal;
    if (smoothType == 1) {
        smoothLocal = smoothInQueue.AllocTensor<float>();
        DataCopyPad(smoothLocal, quantSmoothGm, smoothCopyParams, {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();
    }
    for (int64_t row = startXRow; row <= endXRow; row++) {
        LocalTensor<T> xLocal = xCopyInQueue_.AllocTensor<T>();
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(xLocal, xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        } else {
            // Stage the raw T row at offset colsAlign; Compute() widens it to fp32.
            DataCopyPad(xLocal[colsAlign], xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        }

        xCopyInQueue_.EnQue<T>(xLocal);
        Compute(smoothLocal);

        LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
        LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
        // Scatter the quantized row to every expanded slot of this token.
        while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) {
            int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
            curRowsStart++;
            // Skip dropped slots and (in dropless mode) slots beyond activateRows_.
            if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) {
                continue;
            }
            DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams);
            DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
        }

        xCopyInQueue_.FreeTensor(xLocal);
        inputXOutQueue.FreeTensor(outLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }

    if (smoothType == 1) {
        smoothInQueue.FreeTensor(smoothLocal);
    }
    expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
}
|
||||
|
||||
// Copy-out path for smoothType == 2 (per-expert [E, H] smooth vectors):
// iterates destination rows directly, gathers the matching source row via
// the dst->src mapping, and loads the smooth vector of that row's expert.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuantEH() {
    // The src->dst mapping is not needed on this path; release it.
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);

    // Undo the negation left over from the second sort to recover the
    // dst->src row indices as int32.
    Muls(expandDstToSrcRowLocal.ReinterpretCast<float>(), expandDstToSrcRowLocal.ReinterpretCast<float>(), (float)-1,
         this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast<int32_t>();
    Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast<float>(), RoundMode::CAST_ROUND, this->totalLength);

    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;

    DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};

    for (int64_t row = curRowsStart; row <= curRowsEnd; row++) {
        // Dropless mode: rows past activateRows_ are never consumed.
        if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) {
            break;
        }
        int32_t srcIdx = sortedRowIdx.GetValue(row);
        int32_t expertIdx = expandedExpertIdxLocal.GetValue(row);

        LocalTensor<T> inLocal = xCopyInQueue_.AllocTensor<T>();
        LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        } else {
            // Stage the raw T row at offset colsAlign; Compute() widens it to fp32.
            DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        }
        // Load this expert's smooth vector.
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0});
        xCopyInQueue_.EnQue<T>(inLocal);
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();

        Compute(smoothLocal);

        LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
        DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0});

        LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
        DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams);

        xCopyInQueue_.FreeTensor(inLocal);
        smoothInQueue.FreeTensor(smoothLocal);
        inputXOutQueue.FreeTensor(outLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }

    expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
    expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
|
||||
GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
|
||||
GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale,
|
||||
GM_ADDR workspace,
|
||||
const MoeInitRoutingQuantV2TilingData* tilingData,
|
||||
TPipe* tPipe) {
|
||||
this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp);
|
||||
//this->blockIdx_ = GetBlockIdx();
|
||||
this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num();
|
||||
this->k_ = tilingData->k;
|
||||
this->n_ = tilingData->n;
|
||||
this->cols_ = tilingData->cols;
|
||||
this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
|
||||
this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
|
||||
this->activateRows_ = this->gatherOutTilingData_->activateRows;
|
||||
if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) {
|
||||
this->coreRows_ = this->gatherOutTilingData_->lastCoreRows;
|
||||
} else {
|
||||
this->coreRows_ = this->gatherOutTilingData_->perCoreRows;
|
||||
}
|
||||
this->expertNum = tilingData->expertNum;
|
||||
this->dropPadMode = tilingData->dropPadMode;
|
||||
this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
|
||||
|
||||
this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
|
||||
this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
|
||||
this->totalLength = tilingData->n * tilingData->k;
|
||||
this->smoothType = tilingData->smoothType;
|
||||
this->colsAlign = Align(this->cols_, sizeof(T));
|
||||
this->pipe = tPipe;
|
||||
|
||||
xGm_.SetGlobalBuffer((__gm__ T*)x);
|
||||
expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
|
||||
|
||||
expandedXGm_.SetGlobalBuffer((__gm__ int8_t*)expandedX);
|
||||
expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
|
||||
if (this->expertTokensCountOrCumsumFlag > 0) {
|
||||
// dropless
|
||||
expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
|
||||
Align(this->expertNum, sizeof(int32_t)));
|
||||
}
|
||||
quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
|
||||
dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
|
||||
|
||||
int64_t kvFactor = 2;
|
||||
int64_t buffSize = this->sortNum_ * sizeof(int32_t);
|
||||
|
||||
int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
|
||||
int64_t startXRow = curRowsStart / this->k_;
|
||||
int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_;
|
||||
|
||||
pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize);
|
||||
pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize);
|
||||
pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
|
||||
pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize);
|
||||
pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
|
||||
pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
|
||||
pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
|
||||
|
||||
if constexpr (IsSameType<T, float>::value) {
|
||||
pipe->InitBuffer(xCopyInQueue_, 1, AlignBytes(this->cols_, sizeof(float)));
|
||||
} else {
|
||||
pipe->InitBuffer(xCopyInQueue_, 1, 2 * AlignBytes(this->cols_, sizeof(T)));
|
||||
}
|
||||
pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float)));
|
||||
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float)));
|
||||
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t)));
|
||||
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
|
||||
}
|
||||
|
||||
// Kernel entry: full-load sort, index copy-out (core 0 only), expert-token
// statistics (last active core only), then quantized gather/scatter. Every
// active core must run either ComputeExpertTokenCountOrCumsum() or
// CopyOutEmpty() so the expandedExpertIdx queue stays balanced.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Process() {
    if (this->blockIdx_ < this->needCoreNum_) {
        CopyIn();
        SortCompute();
        // The src->dst mapping is identical on all cores; one writer suffices.
        if (this->blockIdx_ == 0) {
            CopyOutIdx();
        }
        if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
            ComputeExpertTokenCountOrCumsum();
        } else {
            CopyOutEmpty();
        }
        // smoothType 2 uses per-expert smooth vectors; 0/1 share one path.
        if (smoothType == 2) {
            CopyOutXQuantEH();
        } else {
            CopyOutXQuant1H();
        }
    }
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // MOE_V2_DYNAMIC_QUANT_FULL_LOAD_H
|
||||
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
 * \file moe_v2_fullload_quant.h
 * \brief Full-load (single-pass) static-quantization variant of MoeInitRoutingQuantV2.
 */
|
||||
#ifndef MOE_V2_FULL_LOAD_QUANT_H
|
||||
#define MOE_V2_FULL_LOAD_QUANT_H
|
||||
|
||||
#include "moe_v2_fullload_quant_base.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Full-load static-quantization routing kernel: relies on the base class for
// the expert-id sort and routing indices, and quantizes x to int8 using a
// scalar scale/offset pair read from GM.
template <typename T>
class MoeV2FullLoadQuant : public MoeV2FullLoadQuantBase {
 public:
    __aicore__ inline MoeV2FullLoadQuant(){};
    // Binds GM addresses and sizes the UB buffers; call once before Process().
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX,
                                GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
    // Kernel entry: base routing pipeline followed by the quantized copy-out.
    __aicore__ inline void Process();

 private:
    // Quantizes xLocalLength staged rows to int8 (round-to-nearest-even).
    __aicore__ inline void Compute(int64_t xLocalLength);
    // Loads this core's source rows, quantizes them, scatters to expandedX.
    __aicore__ inline void CopyOutX();

 private:
    TQue<QuePosition::VECOUT, 1> floatQueue;          // fp32 scratch (bf16 path)
    TQue<QuePosition::VECOUT, 1> halfQueue;           // fp16 scratch
    TQue<QuePosition::VECOUT, 1> inputXCopyOutQueue;  // int8 output rows

    GlobalTensor<T> xGm;
    GlobalTensor<float> scaleGm;   // single fp32 scale in GM
    GlobalTensor<float> offsetGm;  // single fp32 offset in GM

    float scale;   // cached from scaleGm in Init()
    float offset;  // cached from offsetGm in Init()
};
|
||||
|
||||
// Static quantization of xLocalLength staged rows: out = round(x * scale + offset)
// as int8. Three type paths:
//   bf16 : bf16 -> fp32 -> fp16, scale/offset in fp16, fp16 -> int32 -> fp16 -> int8
//          (the int32 round-trip with SetDeqScale(1.0) performs the rounding)
//   fp32 : fp32 -> fp16, scale/offset in fp16, fp16 -> int8
//   fp16 : scale/offset applied in place, then fp16 -> int8
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::Compute(int64_t xLocalLength) {
    LocalTensor<T> inLocal = xCopyInQueue.DeQue<T>();
    LocalTensor<int8_t> outLocal = inputXCopyOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> floatLocal = floatQueue.AllocTensor<float>();
    LocalTensor<half> halfLocal = halfQueue.AllocTensor<half>();

    // Rows are staged contiguously at the int8-aligned stride.
    uint32_t elements = Align(this->cols, sizeof(int8_t)) * xLocalLength;
    if constexpr (IsSameType<T, bfloat16_t>::value) {
        Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        // Round via an int32 round-trip; floatLocal is reused as the int32 buffer.
        LocalTensor<int32_t> intLocal = floatLocal.ReinterpretCast<int32_t>();
        Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        // Identity dequant scale so the int32 -> half cast preserves the value.
        SetDeqScale((half)1.000000e+00f);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else if constexpr (IsSameType<T, float>::value) {
        Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else {
        Muls(inLocal, inLocal, static_cast<T>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(inLocal, inLocal, static_cast<T>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements);
    }
    inputXCopyOutQueue.EnQue(outLocal);
    xCopyInQueue.FreeTensor(inLocal);
    floatQueue.FreeTensor(floatLocal);
    halfQueue.FreeTensor(halfLocal);
}
|
||||
|
||||
// Loads all source token rows this core is responsible for in one strided
// DataCopyPad, quantizes them, then scatters each quantized row to its
// destination slots (up to k per token) via the src->dst mapping.
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::CopyOutX() {
    LocalTensor<T> xLocal = xCopyInQueue.AllocTensor<T>();
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue.DeQue<int32_t>();
    // Per-row stride inside the staged buffer (int8-aligned column count).
    int64_t inFactor = Align(this->cols, sizeof(int8_t));
    int64_t curRowsStart = this->blockIdx * this->perCoreRows;
    int64_t startXRow = curRowsStart / this->k;
    int64_t endXRow = (curRowsStart + this->coreRows - 1) / this->k;

    // Destination gap so each row lands at the inFactor stride.
    uint32_t dstStride = (inFactor * sizeof(T) - AlignBytes(this->cols, sizeof(T))) / BLOCK_BYTES;
    DataCopyExtParams dataXCopyParams{static_cast<uint16_t>(endXRow - startXRow + 1),
                                      static_cast<uint32_t>(this->cols * sizeof(T)), 0, dstStride, 0};
    DataCopyPadExtParams<T> dataXCopyPadParams{false, 0, 0, 0};
    DataCopyPad(xLocal, xGm[startXRow * this->cols], dataXCopyParams, dataXCopyPadParams);
    xCopyInQueue.EnQue(xLocal);
    Compute(endXRow - startXRow + 1);
    LocalTensor<int8_t> outLocal = inputXCopyOutQueue.DeQue<int8_t>();
    // NOTE: local counter `k` shadows the member top-k `this->k` below.
    int64_t k = 0;
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
    for (int64_t i = startXRow; i <= endXRow; i++) {
        // Emit every expanded slot belonging to source row i.
        for (; k < this->perCoreRows && curRowsStart / this->k == i; curRowsStart++, k++) {
            int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
            // NOTE(review): no explicit outIndex == -1 guard here, unlike the
            // dynamic-quant path; presumably this mode never produces -1 — confirm.
            if (outIndex < this->activateRows) {
                DataCopyPad(expandedXGm[outIndex * this->cols], outLocal[(i - startXRow) * inFactor], intriParams);
            }
        }
    }
    expandedRowIdxCopyOutQueue.FreeTensor(expandedRowIdx);
    inputXCopyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Binds GM addresses, caches the scalar quant scale/offset from GM, and sizes
// the UB buffers to hold every source row this core processes at once.
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset,
                                                   GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                                   GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                                   const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe) {
    this->InitBase(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe);
    xGm.SetGlobalBuffer((__gm__ T*)x);
    scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1);
    offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1);
    // Static quantization: one scalar scale/offset pair for the whole tensor.
    this->scale = scaleGm.GetValue(0);
    this->offset = offsetGm.GetValue(0);

    // Number of distinct source token rows covered by this core's expanded rows.
    int64_t curRowsStart = this->blockIdx * this->perCoreRows;
    int64_t rowLength = (curRowsStart + this->coreRows - 1) / this->k - curRowsStart / this->k + 1;
    int64_t xAlignedCount = Align(this->cols, sizeof(int8_t));
    pipe->InitBuffer(xCopyInQueue, bufferNum, xAlignedCount * sizeof(T) * rowLength);
    pipe->InitBuffer(inputXCopyOutQueue, 1, xAlignedCount * sizeof(int8_t) * rowLength);
    pipe->InitBuffer(floatQueue, 1, xAlignedCount * sizeof(float) * rowLength);
    pipe->InitBuffer(halfQueue, 1, xAlignedCount * sizeof(half) * rowLength);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoadQuant<T>::Process() {
|
||||
if (this->blockIdx < this->needCoreNum) {
|
||||
this->ProcessBase();
|
||||
CopyOutX();
|
||||
}
|
||||
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif
|
||||
@@ -0,0 +1,279 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
 * \file moe_v2_fullload_quant_base.h
 * \brief Shared full-load routing logic (expert-id sort, routing indices, expert-token statistics) for the quant kernels.
 */
|
||||
#ifndef MOE_V2_FULL_LOAD_QUANT_BASE_H
|
||||
#define MOE_V2_FULL_LOAD_QUANT_BASE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Shared base for the full-load quant kernels: loads the whole expert-index
// tensor at once, sorts tokens by expert id, and produces the routing indices
// and (optionally) per-expert token count/cumsum statistics.
class MoeV2FullLoadQuantBase {
 public:
    __aicore__ inline MoeV2FullLoadQuantBase(){};

 protected:
    // Binds common GM buffers and unpacks tiling data; call before ProcessBase().
    __aicore__ inline void InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                    GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                    const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
    // Runs the common routing pipeline (defined outside this view).
    __aicore__ inline void ProcessBase();
    // Stages expert ids + a row-id ramp into UB for the sort.
    __aicore__ inline void CopyIn();
    // Sorts tokens by expert id; emits sorted expert ids and the src->dst map.
    __aicore__ inline void SortCompute();
    // Writes the src->dst row mapping to GM.
    __aicore__ inline void CopyOutIdx();
    // Queue-balancing counterpart for cores that skip the statistics step.
    __aicore__ inline void CopyOutEmpty();
    // Builds per-expert token count or cumsum and writes it to GM.
    __aicore__ inline void ComputeExpertTokenCountOrCumsum();

 protected:
    const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;

    TPipe* pipe;
    int64_t tileLength;      // aligned expert-index element count
    int64_t bufferNum = 1;   // queue depth
    int64_t totalLength;     // n * k expanded rows
    int64_t coreNum;
    int64_t sortNum;         // tileLength rounded up to a sort-repeat multiple
    int64_t blockIdx;        // flat vector-core index
    int64_t needCoreNum;     // number of active cores
    int64_t coreRows;        // expanded rows handled by this core
    int64_t perCoreRows;     // expanded rows per core (except possibly the last)
    int64_t k;               // top-k experts per token
    int64_t n;               // number of tokens
    int64_t cols;            // hidden size (H)
    int64_t activateRows;    // activated-row bound (dropless mode)
    int64_t expertNum;
    int64_t expertCapacity;

    TQue<QuePosition::VECIN, 1> sortDataCopyInQueue;
    TBuf<TPosition::VECCALC> tempBuffer;    // Concat/Sort scratch
    TBuf<TPosition::VECCALC> sortedBuffer;  // Sort output scratch
    TQue<QuePosition::VECIN, 1> xCopyInQueue;
    TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue;
    TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue;
    TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue;
    TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue;

    GlobalTensor<int32_t> expertIdxGm;
    GlobalTensor<int8_t> expandedXGm;
    GlobalTensor<int32_t> expandedRowIdxGm;
    GlobalTensor<int32_t> expandedExpertIdxGm;
    GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
    GlobalTensor<int32_t> expertTokensBeforeCapacityGm;

    int64_t expertTokensCountOrCumsumFlag = 0;   // EXERPT_TOKENS_NONE/COUNT/CUMSUM
    int64_t expertTokensBeforeCapacityFlag = 0;
    int64_t dropPadMode = 0;
    // Vector-instruction stride constants.
    static constexpr int64_t DST_BLK_STRIDE = 1;
    static constexpr int64_t DST_REP_STRIDE = 8;
    static constexpr int64_t FOUR_BLOCK_BYTES = 128;
};
|
||||
|
||||
// Loads the whole expert-index tensor into UB and appends a 0..totalLength-1
// row-id ramp at offset sortNum, forming the key/value buffer for SortCompute().
__aicore__ inline void MoeV2FullLoadQuantBase::CopyIn() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
    DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                     0, 0, 0};
    DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
    DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams);
    // Row indices 0,1,2,... serve as the sort "values" paired with expert-id keys.
    ArithProgression<int32_t>(inLocal[this->sortNum], 0, 1, this->totalLength);
    sortDataCopyInQueue.EnQue(inLocal);
}
|
||||
|
||||
// Sorts all (expertId, rowId) pairs by expert id in a single full-load pass.
// Enqueues the ascending expert ids and the inverse (src -> dst) row mapping;
// unlike the dynamic-quant variant, the dst->src tensor is freed here because
// the static-quant copy-out path does not need it.
// The AscendC Sort primitive sorts fp32 keys in descending order, so keys are
// negated before sorting and negated back afterwards.
__aicore__ inline void MoeV2FullLoadQuantBase::SortCompute() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertIdxLocal = inLocal[0];
    // Reuse the int32 buffer in place as fp32 sort keys.
    LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
    Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Pad the tail of the final sort repeat with MIN_FP32 so padding sorts last.
    int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    // First sort: key = -expertId, value = source row index.
    LocalTensor<float> concatLocal;
    LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum));
    Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum].template ReinterpretCast<uint32_t>();
    LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum));
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue.AllocTensor<float>();
    LocalTensor<uint32_t> expandDstToSrcRowLocal = expandDstToSrcRowQueue.AllocTensor<uint32_t>();
    LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
    Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
         this->totalLength);
    pipe_barrier(PIPE_V);
    // Undo the key negation to recover ascending expert ids.
    Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> expandedExpertIdxLocalInt32;
    expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
    Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    expandedExpertIdxCopyOutQueue.EnQue<int32_t>(expandedExpertIdxLocalInt32);

    // Second sort: key = -(dst->src row), value = a fresh 0..N-1 ramp; its
    // Extract yields the inverse (src -> dst) permutation.
    LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue.AllocTensor<uint32_t>();
    LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
    Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    ArithProgression<int32_t>(inLocal[this->sortNum], 0, 1, this->totalLength);
    pipe_barrier(PIPE_V);
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    expandedRowIdxCopyOutQueue.EnQue<uint32_t>(expandedRowIdx);
    sortDataCopyInQueue.FreeTensor(inLocal);

    expandDstToSrcRowQueue.FreeTensor(expandDstToSrcRowLocal);
}
|
||||
|
||||
// Writes the full expanded row-index table back to global memory in a single
// padded copy, then re-enqueues the tensor for the next pipeline consumer.
__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutIdx() {
    LocalTensor<int32_t> rowIdxLocal = expandedRowIdxCopyOutQueue.DeQue<int32_t>();
    DataCopyParams copyParams;
    copyParams.blockLen = this->totalLength * sizeof(int32_t);
    copyParams.blockCount = 1;
    DataCopyPad(expandedRowIdxGm, rowIdxLocal, copyParams);
    expandedRowIdxCopyOutQueue.EnQue(rowIdxLocal);
}
|
||||
|
||||
// Scans the sorted expanded expert ids and produces, per expert, either the
// number of tokens routed to it (EXERPT_TOKENS_COUNT) or the running prefix
// sum of tokens (EXERPT_TOKENS_CUMSUM), then copies the result to GM.
// Runs as a scalar loop, so it is executed on a single core only.
__aicore__ inline void MoeV2FullLoadQuantBase::ComputeExpertTokenCountOrCumsum() {
    LocalTensor<int32_t> expandedExpertIdx = expandedExpertIdxCopyOutQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue.AllocTensor<int32_t>();

    // Zero the whole (aligned) output so experts that receive no tokens report 0.
    int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
    Duplicate(expertTokensCount, 0, expertNumAlign);
    // Vector Duplicate must land before the scalar SetValue/GetValue below.
    SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);

    // expandedExpertIdx is sorted ascending, so each expert's tokens form one
    // contiguous run; walk the runs and flush a value on each expert change.
    int32_t lastExpertId = expandedExpertIdx.GetValue(0);
    int64_t tokenCount = 0;
    for (int64_t i = 0; i < this->totalLength; i++) {
        int32_t curExpertId = expandedExpertIdx.GetValue(i);
        tokenCount++;
        // Flush every expert id we just stepped past (handles gaps: experts
        // with no tokens get the count so far, i.e. the cumsum, or 0 in
        // count mode via the reset below).
        while (lastExpertId < curExpertId) {
            expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
            if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
                tokenCount = 1;  // count mode: restart counting for the new run
            }
            lastExpertId++;
        }
    }
    // Flush the final run.
    expertTokensCount.SetValue(lastExpertId, tokenCount);
    if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
        // Trailing experts with no tokens inherit the total prefix sum.
        lastExpertId++;
        while (lastExpertId < this->expertNum) {
            expertTokensCount.SetValue(lastExpertId, tokenCount);
            lastExpertId++;
        }
    }
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
                                 0};
    if (this->expertTokensCountOrCumsumFlag > 0) {
        DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
    }
    expertTokensCopyOutQueue.FreeTensor(expertTokensCount);
    expandedExpertIdxCopyOutQueue.FreeTensor(expandedExpertIdx);
}
|
||||
|
||||
// Drains and frees one tensor from the expanded-expert-idx queue without
// producing any output. Called on cores that do not run the token-count step,
// so the tensor enqueued by SortCompute is still released.
__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutEmpty() {
    LocalTensor<int32_t> drained = expandedExpertIdxCopyOutQueue.DeQue<int32_t>();
    expandedExpertIdxCopyOutQueue.FreeTensor(drained);
}
|
||||
|
||||
// Reads tiling parameters, binds the kernel's GM buffers, and carves the UB
// queues/buffers used by the full-load sort pipeline.
//
// Params: raw GM addresses for inputs/outputs/workspace, the host-computed
// tiling data, and the TPipe that owns all UB allocations.
__aicore__ inline void MoeV2FullLoadQuantBase::InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
                                                        GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
                                                        GM_ADDR workspace,
                                                        const MoeInitRoutingQuantV2TilingData* tilingData,
                                                        TPipe* tPipe) {
    this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);
    // Logical core id across both sub-blocks of the AI core.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->k = tilingData->k;
    this->n = tilingData->n;
    this->cols = tilingData->cols;
    this->needCoreNum = this->gatherOutTilingData->needCoreNum;
    this->perCoreRows = this->gatherOutTilingData->perCoreRows;
    this->activateRows = this->gatherOutTilingData->activateRows;
    // The last participating core takes the (possibly smaller) remainder.
    if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
        this->coreRows = this->gatherOutTilingData->lastCoreRows;
    } else {
        this->coreRows = this->gatherOutTilingData->perCoreRows;
    }
    this->expertNum = tilingData->expertNum;
    this->dropPadMode = tilingData->dropPadMode;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;

    this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
    // Sort primitives operate on whole repeats; round up to the repeat size.
    this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    this->totalLength = tilingData->n * tilingData->k;
    this->pipe = tPipe;

    expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);

    expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);
    expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
    if (this->expertTokensCountOrCumsumFlag > 0) {
        // dropless
        expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                    Align(this->expertNum, sizeof(int32_t)));
    }

    // Sort buffers hold key/value pairs, hence the factor of 2.
    int64_t kvFactor = 2;
    int64_t buffSize = this->sortNum * sizeof(int32_t);

    pipe->InitBuffer(expandedRowIdxCopyOutQueue, bufferNum, buffSize);
    pipe->InitBuffer(expandedExpertIdxCopyOutQueue, bufferNum, buffSize);
    pipe->InitBuffer(expertTokensCopyOutQueue, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
    pipe->InitBuffer(expandDstToSrcRowQueue, bufferNum, buffSize);
    pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
    pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
    pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
}
|
||||
|
||||
// Runs the full-load routing pipeline on each participating core:
// load + sort all expert ids, then write the row-index table (core 0 only)
// and the per-expert token count/cumsum (last core only). Cores that skip
// the count step still drain their queue via CopyOutEmpty.
__aicore__ inline void MoeV2FullLoadQuantBase::ProcessBase() {
    if (this->blockIdx < this->needCoreNum) {
        CopyIn();
        SortCompute();
        // Only one core needs to emit the (shared) expandedRowIdx output.
        if (this->blockIdx == 0) {
            CopyOutIdx();
        }
        if (this->blockIdx == this->needCoreNum - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
            ComputeExpertTokenCountOrCumsum();
        } else {
            // Balance queue occupancy: release the tensor SortCompute enqueued.
            CopyOutEmpty();
        }
    }
}
|
||||
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // MOE_V2_FULL_LOAD_QUANT_BASE_H
|
||||
@@ -0,0 +1,568 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_gather_dynamic_quant.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_V2_GATHER_DYNAMIC_QUANT_H
|
||||
#define MOE_V2_GATHER_DYNAMIC_QUANT_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Gathers token rows according to the MoE routing result and applies
// per-row dynamic int8 quantization (optionally after multiplying by a
// "smooth" scale vector). T is the input dtype (float or a 16-bit float
// type that is cast to fp32 before quantization).
template <typename T>
class MoeV2GatherDynamicQuant {
 public:
  __aicore__ inline MoeV2GatherDynamicQuant(){};
  __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                              GM_ADDR dynamicQuantScale, GM_ADDR workspace,
                              const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();

 private:
  // Tile loaders for the two index layouts (plain row idx vs. sorted idx +
  // expert id pairs).
  __aicore__ inline void CopyInExpandedRowIdx(int64_t progress);
  __aicore__ inline void CopyInExpandedExpertIdx(int64_t progress);
  // Full-row quantization paths: 1H = one shared smooth vector, EH = one
  // smooth vector per expert.
  __aicore__ inline void CopyOutXQuant1H(int64_t progress);
  __aicore__ inline void CopyOutXQuantEH(int64_t progress);
  __aicore__ inline void Compute(LocalTensor<float>& smoothLocal);
  // Tiled (row does not fit in UB) quantization paths; use quantSrcGm as
  // fp32 scratch in workspace.
  __aicore__ inline void CopyOutPartialXQuantEH(int64_t progress);
  __aicore__ inline void CopyOutPartialXQuant1H(int64_t progress);
  __aicore__ inline float ComputeMax(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal,
                                     LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx,
                                     int64_t j);
  __aicore__ inline void ComputeScale(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal, float scaleTemp,
                                      int64_t dstIndex, int64_t j);

 private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, BUFFER_NUM> inputXInQueue;
  TQue<QuePosition::VECIN, BUFFER_NUM> smoothInQueue;
  TQue<QuePosition::VECIN, BUFFER_NUM> expandRowIdxInQueue;
  TQue<QuePosition::VECOUT, 1> calcQueue;
  TQue<QuePosition::VECOUT, 1> inputXOutQueue;
  TQue<QuePosition::VECOUT, 1> scaleOutQueue;

  GlobalTensor<T> inputXGm;
  GlobalTensor<int8_t> expandedXGm;      // quantized output rows
  GlobalTensor<int32_t> expandedRowIdxGm;
  GlobalTensor<float> quantSmoothGm;     // smooth scale vector(s)
  GlobalTensor<float> dynamicQuantScaleGm;  // one fp32 scale per output row
  GlobalTensor<float> quantSrcGm;        // fp32 scratch in workspace (tiled path)
  GlobalTensor<int32_t> expandedExpertIdxGm;
  GlobalTensor<int32_t> sortedRowIdxGm;

  const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;

  int64_t needCoreNum;
  int64_t blockIdx;
  int64_t cols;          // hidden size H
  int64_t n;             // number of tokens
  int64_t k;             // experts per token
  int64_t totalLength;   // n * k expanded rows
  int64_t activateRows;
  int64_t currentLoopRows;
  int64_t currentLoopRowsAlign;
  int64_t coreRows;
  int64_t perLoopRows;
  int64_t lastLoopRows;
  int64_t rowLoops;
  int64_t colsTileLength;
  int64_t perLoopCols;
  int64_t perLoopColsAlign;
  int64_t lastLoopCols;
  int64_t colLoops;      // > 1 means a row does not fit in UB
  int64_t dropPadMode;
  int64_t smoothType;    // 0: none, 1: shared smooth, 2: per-expert smooth

  int64_t indicesOffset;
  int64_t inputOffset;
  int64_t outOffset;
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyInExpandedRowIdx(int64_t progress) {
|
||||
this->indicesOffset = progress * this->perLoopRows;
|
||||
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.AllocTensor<int32_t>();
|
||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
|
||||
DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
|
||||
expandRowIdxInQueue.EnQue<int32_t>(indicesLocal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyInExpandedExpertIdx(int64_t progress) {
|
||||
this->indicesOffset = progress * this->perLoopRows;
|
||||
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.AllocTensor<int32_t>();
|
||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
|
||||
DataCopyPad(indicesLocal, sortedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
|
||||
DataCopyPad(indicesLocal[currentLoopRowsAlign], expandedExpertIdxGm[indicesOffset], dataCopyParams,
|
||||
dataCopyPadParams);
|
||||
expandRowIdxInQueue.EnQue<int32_t>(indicesLocal);
|
||||
}
|
||||
|
||||
// Dynamically quantizes one full row (this->cols elements) sitting in
// inputXInQueue to int8: optionally applies the smooth scale, computes
// scale = max(|x|) / 127, then emits round(x / scale) as int8 plus the scale.
// Enqueues the int8 row on inputXOutQueue and the scale on scaleOutQueue;
// the caller dequeues and frees both.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& smoothLocal) {
    LocalTensor<float> inLocal = inputXInQueue.DeQue<float>();

    LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();

    // Non-fp32 input was staged at offset perLoopColsAlign so the fp32 cast
    // can write to the front of the same buffer without overlap.
    if constexpr (!IsSameType<T, float>::value) {
        Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
        pipe_barrier(PIPE_V);
    }

    if (smoothType != 0) {
        Mul(inLocal, inLocal, smoothLocal, this->cols);
        pipe_barrier(PIPE_V);
    }

    Abs(tempLocal, inLocal, this->cols);
    pipe_barrier(PIPE_V);

    ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols);
    pipe_barrier(PIPE_V);

    // NOTE(review): scalar GetValue after a vector ReduceMax — presumably the
    // PIPE_V barrier is sufficient here; confirm no V->S event is required.
    // NOTE(review): an all-zero row yields maxValue == 0 and a divide-by-zero
    // below — verify upstream guarantees or intended NaN/Inf behavior.
    float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;

    // Broadcast the scale: 8 lanes for the scale output, full row for the divisor.
    Duplicate<float>(dynamicQuantLocal, maxValue, 8);
    Duplicate<float>(tempLocal, maxValue, this->cols);
    pipe_barrier(PIPE_V);

    Div(tempLocal, inLocal, tempLocal, this->cols);
    pipe_barrier(PIPE_V);

    // fp32 -> half -> int8; the half intermediate applies round-to-nearest on
    // the final cast.
    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols);
    pipe_barrier(PIPE_V);

    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols);

    calcQueue.FreeTensor(tempLocal);
    inputXOutQueue.EnQue(outLocal);
    scaleOutQueue.EnQue(dynamicQuantLocal);
}
|
||||
|
||||
// Full-row quantization with a single shared smooth vector (1H layout, or no
// smooth at all). Each source token row is loaded and quantized once, then
// scattered to every expanded destination row that maps back to it
// (consecutive entries in expandedRowIdx share source row idx/k).
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progress) {
    LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();

    // First expanded row handled by this (core, loop) pair.
    int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
    int64_t curLoopRow = 0;
    // Source token rows covered by this tile (expanded rows / k).
    int64_t currentLoopStartRow = initialRow / this->k;
    int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
    DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
    DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};

    // Load the shared smooth vector once per tile (only used when smoothType == 1).
    LocalTensor<float> smoothLocal;
    if (smoothType == 1) {
        smoothLocal = smoothInQueue.AllocTensor<float>();
        DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();
    }

    for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
        LocalTensor<T> inLocal = inputXInQueue.AllocTensor<T>();
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(inLocal, inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0});
        } else {
            // Stage non-fp32 input in the upper half; Compute() casts to the front.
            DataCopyPad(inLocal[perLoopColsAlign], inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0});
        }

        inputXInQueue.EnQue<T>(inLocal);

        // Quantize the row.
        Compute(smoothLocal);

        LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
        LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();

        // Scatter this source row to every expanded destination derived from it.
        while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
            int32_t outIndex = indicesLocal.GetValue(curLoopRow);
            curLoopRow++;
            initialRow++;
            // -1 marks a dropped row; dropless mode skips rows past activateRows.
            if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
                continue;
            }
            DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams);
            DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
        }
        inputXInQueue.FreeTensor(inLocal);
        inputXOutQueue.FreeTensor(outLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }
    if (smoothType == 1) {
        smoothInQueue.FreeTensor(smoothLocal);
    }
    expandRowIdxInQueue.FreeTensor(indicesLocal);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuantEH(int64_t progress) {
|
||||
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
|
||||
SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
|
||||
|
||||
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(T)), 0, 0, 0};
|
||||
DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(float)), 0, 0, 0};
|
||||
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(int8_t)), 0, 0, 0};
|
||||
|
||||
int32_t lastExpertIdx = -1;
|
||||
LocalTensor<T> inLocal = inputXInQueue.AllocTensor<T>();
|
||||
LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
|
||||
for (int64_t i = 0; i < this->currentLoopRows; i++) {
|
||||
int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
|
||||
if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) {
|
||||
break;
|
||||
}
|
||||
int32_t srcIdx = indicesLocal.GetValue(i);
|
||||
int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i);
|
||||
|
||||
if constexpr (IsSameType<T, float>::value) {
|
||||
DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
|
||||
} else {
|
||||
DataCopyPad(inLocal[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
|
||||
}
|
||||
inputXInQueue.EnQue<T>(inLocal);
|
||||
|
||||
if (expertIdx != lastExpertIdx) {
|
||||
DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0});
|
||||
smoothInQueue.EnQue(smoothLocal);
|
||||
smoothLocal = smoothInQueue.DeQue<float>();
|
||||
lastExpertIdx = expertIdx;
|
||||
}
|
||||
|
||||
Compute(smoothLocal);
|
||||
|
||||
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
|
||||
DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0});
|
||||
|
||||
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
|
||||
DataCopyPad(expandedXGm[(rowOffset + i) * this->cols], outLocal, copyOutParams);
|
||||
|
||||
inputXOutQueue.FreeTensor(outLocal);
|
||||
scaleOutQueue.FreeTensor(quantScaleLocal);
|
||||
}
|
||||
|
||||
inputXInQueue.FreeTensor(inLocal);
|
||||
smoothInQueue.FreeTensor(smoothLocal);
|
||||
expandRowIdxInQueue.FreeTensor(indicesLocal);
|
||||
}
|
||||
|
||||
// Tiled path, pass 1: loads column tile j of source row srcIdx, optionally
// multiplies by the smooth vector (expertIdx selects the vector when
// smoothType == 2), spills the smoothed fp32 tile to the quantSrcGm workspace,
// and returns the tile's max(|x|) for the caller's row-wide reduction.
template <typename T>
__aicore__ inline float MoeV2GatherDynamicQuant<T>::ComputeMax(LocalTensor<float>& inLocal,
                                                               LocalTensor<float>& tempLocal,
                                                               LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx,
                                                               int32_t expertIdx, int64_t j) {
    LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();

    DataCopyExtParams intriParamsT{1, static_cast<uint32_t>(colsTileLength * sizeof(T)), 0, 0, 0};
    DataCopyExtParams intriParamsFp32{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};

    // Non-fp32 input is staged in the upper half so the cast below can write
    // to the front of the same buffer.
    if constexpr (!IsSameType<T, float>::value) {
        DataCopyPad(inLocal.ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols],
                    intriParamsT, {false, 0, 0, 0});
    } else {
        DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
    }

    inputXInQueue.EnQue<float>(inLocal);
    inLocal = inputXInQueue.DeQue<float>();

    if (smoothType != 0) {
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32,
                    {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();
    }

    if constexpr (!IsSameType<T, float>::value) {
        Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength);
        pipe_barrier(PIPE_V);
    }

    if (smoothType != 0) {
        Mul(inLocal, inLocal, smoothLocal, colsTileLength);
        pipe_barrier(PIPE_V);
    }

    Abs(tempLocal, inLocal, colsTileLength);
    pipe_barrier(PIPE_V);

    // Per-tile max written at offset 8 so it does not clobber the row scale
    // held in dynamicQuantLocal[0..7].
    ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength);

    // Spill the smoothed fp32 tile so pass 2 (ComputeScale) can reread it.
    DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32);
    smoothInQueue.FreeTensor(smoothLocal);
    // Wait for the workspace write before the next tile's load reuses buffers.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);

    // NOTE(review): scalar GetValue right after vector ReduceMax with no
    // explicit V->S sync — confirm the intervening flag covers this.
    return dynamicQuantLocal.GetValue(8);
}
|
||||
|
||||
// Tiled path, pass 2: rereads fp32 tile j from the quantSrcGm workspace,
// divides by the row-wide scale scaleTemp, casts to int8 (via half, with
// round-to-nearest on the final cast) and writes the tile to output row
// dstIndex of expandedXGm.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::ComputeScale(LocalTensor<float>& inLocal,
                                                                LocalTensor<float>& tempLocal, float scaleTemp,
                                                                int64_t dstIndex, int64_t j) {
    DataCopyExtParams copyInParams{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(colsTileLength * sizeof(int8_t)), 0, 0, 0};

    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();

    DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0});
    inputXInQueue.EnQue<float>(inLocal);
    inLocal = inputXInQueue.DeQue<float>();

    // Broadcast the scale across the tile, then divide elementwise.
    Duplicate<float>(tempLocal, scaleTemp, colsTileLength);
    pipe_barrier(PIPE_V);

    Div(tempLocal, inLocal, tempLocal, colsTileLength);
    pipe_barrier(PIPE_V);

    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength);
    pipe_barrier(PIPE_V);

    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, colsTileLength);

    inputXOutQueue.EnQue(outLocal);
    outLocal = inputXOutQueue.DeQue<int8_t>();
    DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams);

    inputXOutQueue.FreeTensor(outLocal);
    // Wait for the GM write before the next tile's load reuses the buffers.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutPartialXQuantEH(int64_t progress) {
|
||||
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
|
||||
SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
|
||||
|
||||
for (int64_t i = 0; i < this->currentLoopRows; i++) {
|
||||
int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
|
||||
if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) {
|
||||
break;
|
||||
}
|
||||
int32_t srcIdx = indicesLocal.GetValue(i);
|
||||
int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i);
|
||||
|
||||
LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
|
||||
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
|
||||
LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();
|
||||
|
||||
uint32_t tmp = 0xFF7FFFFF;
|
||||
float reduceMax = *((float*)&tmp);
|
||||
for (int64_t j = 0; j < this->colLoops; j++) {
|
||||
colsTileLength = this->perLoopCols;
|
||||
if (j == this->colLoops - 1) {
|
||||
colsTileLength = this->lastLoopCols;
|
||||
}
|
||||
float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j);
|
||||
reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
|
||||
}
|
||||
|
||||
float scaleTemp = reduceMax / 127.0f;
|
||||
Duplicate<float>(quantScaleLocal, scaleTemp, 8);
|
||||
scaleOutQueue.EnQue(quantScaleLocal);
|
||||
quantScaleLocal = scaleOutQueue.DeQue<float>();
|
||||
|
||||
DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0});
|
||||
|
||||
for (int64_t j = 0; j < this->colLoops; j++) {
|
||||
colsTileLength = this->perLoopCols;
|
||||
if (j == this->colLoops - 1) {
|
||||
colsTileLength = this->lastLoopCols;
|
||||
}
|
||||
|
||||
ComputeScale(inLocal, tempLocal, scaleTemp, rowOffset + i, j);
|
||||
}
|
||||
|
||||
inputXInQueue.FreeTensor(inLocal);
|
||||
calcQueue.FreeTensor(tempLocal);
|
||||
scaleOutQueue.FreeTensor(quantScaleLocal);
|
||||
}
|
||||
|
||||
expandRowIdxInQueue.FreeTensor(indicesLocal);
|
||||
}
|
||||
|
||||
// Tiled quantization with a shared smooth vector (1H layout) or no smooth,
// for rows too wide to fit in UB. Each source token row gets a two-pass
// max-then-scale quantization, and the result is scattered to every expanded
// destination row derived from it (entries sharing initialRow / k).
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutPartialXQuant1H(int64_t progress) {
    LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();

    // First expanded row handled by this (core, loop) pair.
    int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
    int64_t curLoopRow = 0;

    // Source token rows covered by this tile (expanded rows / k).
    int64_t currentLoopStartRow = initialRow / this->k;
    int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;

    for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
        LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
        LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
        LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();

        // Start the running max at -FLT_MAX (bit pattern 0xFF7FFFFF).
        uint32_t tmp = 0xFF7FFFFF;
        float reduceMax = *((float*)&tmp);
        // Pass 1: per-tile max over the whole row (expertIdx 0: shared smooth).
        for (int64_t j = 0; j < this->colLoops; j++) {
            colsTileLength = this->perLoopCols;
            if (j == this->colLoops - 1) {
                colsTileLength = this->lastLoopCols;
            }

            float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, row, 0, j);
            reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
        }

        float scaleTemp = reduceMax / 127.0f;
        Duplicate<float>(quantScaleLocal, scaleTemp, 8);
        scaleOutQueue.EnQue(quantScaleLocal);
        quantScaleLocal = scaleOutQueue.DeQue<float>();

        // Scatter: quantize once per destination mapped to this source row.
        while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
            int32_t outIndex = indicesLocal.GetValue(curLoopRow);
            curLoopRow++;
            initialRow++;
            // -1 marks a dropped row; dropless mode skips rows past activateRows.
            if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
                continue;
            }
            DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
            // Pass 2: quantize each tile with the row-wide scale.
            for (int64_t j = 0; j < this->colLoops; j++) {
                colsTileLength = this->perLoopCols;
                if (j == this->colLoops - 1) {
                    colsTileLength = this->lastLoopCols;
                }

                ComputeScale(inLocal, tempLocal, scaleTemp, outIndex, j);
            }
        }
        inputXInQueue.FreeTensor(inLocal);
        calcQueue.FreeTensor(tempLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }

    expandRowIdxInQueue.FreeTensor(indicesLocal);
}
|
||||
|
||||
// Reads tiling parameters, binds GM buffers (including the three workspace
// regions: expanded expert ids, sorted row ids, and per-core fp32 spill for
// the tiled path), and carves the UB queues for the gather+quant pipeline.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx,
                                                        GM_ADDR expandedX, GM_ADDR dynamicQuantScale, GM_ADDR workspace,
                                                        const MoeInitRoutingQuantV2TilingData* tilingData,
                                                        TPipe* tPipe) {
    this->pipe = tPipe;
    // Logical core id across both sub-blocks of the AI core.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);

    this->needCoreNum = this->gatherOutTilingData->needCoreNum;
    this->activateRows = this->gatherOutTilingData->activateRows;
    this->cols = tilingData->cols;
    this->n = tilingData->n;
    this->k = tilingData->k;
    this->totalLength = tilingData->n * tilingData->k;
    this->dropPadMode = tilingData->dropPadMode;
    this->smoothType = tilingData->smoothType;

    // The last participating core takes the (possibly smaller) remainder.
    if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
        this->coreRows = this->gatherOutTilingData->lastCoreRows;
        this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->gatherOutTilingData->perCoreRows;
        this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->perCoreLoops;
    }
    this->perLoopCols = this->gatherOutTilingData->perLoopCols;
    this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
    this->colLoops = this->gatherOutTilingData->colLoops;
    this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T));

    inputXGm.SetGlobalBuffer((__gm__ T*)inputX);
    expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);

    // Each core views only its own slice of the row-index table.
    expandedRowIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));

    quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
    dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);

    // Workspace layout: [expanded expert ids | sorted row ids | fp32 spill].
    expandedExpertIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)workspace + this->blockIdx * this->gatherOutTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));
    sortedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) +
                                       this->blockIdx * this->gatherOutTilingData->perCoreRows,
                                   Align(this->coreRows, sizeof(int32_t)));
    if (this->cols > 1) {
        // Per-core fp32 scratch row used by the tiled (colLoops > 1) path.
        quantSrcGm.SetGlobalBuffer(
            (__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2 + this->blockIdx * this->cols,
            this->cols * sizeof(float));
    }

    this->currentLoopRowsAlign = Align(this->perLoopRows, sizeof(int32_t));

    // Input buffer must hold the tile both as T (staged) and as fp32 (cast).
    int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T));
    perLoopColsAlignBytes =
        Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES));

    // Index queue holds two arrays (sorted row idx + expert idx), hence 2x.
    pipe->InitBuffer(expandRowIdxInQueue, BUFFER_NUM, 2 * AlignBytes(this->perLoopRows, sizeof(int32_t)));
    pipe->InitBuffer(inputXInQueue, BUFFER_NUM, perLoopColsAlignBytes);
    pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float)));
    pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
    pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
    pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
}
|
||||
|
||||
// Dispatches the gather+quant loop to one of four paths based on whether a
// full row fits in UB (colLoops) and whether the smooth scale is per-expert
// (smoothType == 2). In every path the final loop iteration runs outside the
// for-loop with currentLoopRows switched to the (smaller) lastLoopRows.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
    if (this->blockIdx < this->needCoreNum) {
        currentLoopRows = perLoopRows;
        if (colLoops > 1) {  // a row cannot be fully loaded into UB; use workspace
            if (smoothType == 2) {
                for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
                    CopyInExpandedExpertIdx(loop);
                    CopyOutPartialXQuantEH(loop);
                }
                currentLoopRows = lastLoopRows;
                CopyInExpandedExpertIdx(this->rowLoops - 1);
                CopyOutPartialXQuantEH(this->rowLoops - 1);
            } else {
                for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
                    CopyInExpandedRowIdx(loop);
                    CopyOutPartialXQuant1H(loop);
                }
                currentLoopRows = lastLoopRows;
                CopyInExpandedRowIdx(this->rowLoops - 1);
                CopyOutPartialXQuant1H(this->rowLoops - 1);
            }
        } else {  // a full row fits in UB
            if (smoothType == 2) {
                for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
                    CopyInExpandedExpertIdx(loop);
                    CopyOutXQuantEH(loop);
                }
                currentLoopRows = lastLoopRows;
                CopyInExpandedExpertIdx(this->rowLoops - 1);
                CopyOutXQuantEH(this->rowLoops - 1);
            } else {
                for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
                    CopyInExpandedRowIdx(loop);
                    CopyOutXQuant1H(loop);
                }
                currentLoopRows = lastLoopRows;
                CopyInExpandedRowIdx(this->rowLoops - 1);
                CopyOutXQuant1H(this->rowLoops - 1);
            }
        }
    }
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // MOE_V2_GATHER_DYNAMIC_QUANT_H
|
||||
@@ -0,0 +1,181 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_gather_out.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_GATHER_OUT_H
|
||||
#define INNER_MOE_V2_GATHER_OUT_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
constexpr int64_t BUFFER_NUM = 2;
|
||||
|
||||
// Gathers rows of the input activations into the expanded (token*k) layout,
// following the destination-row mapping in expandedRowIdx. No quantization.
template <typename T>
class MoeV2GatherOut {
public:
    __aicore__ inline MoeV2GatherOut(){};
    // Bind global buffers, read tiling scalars, and allocate UB queues.
    __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace,
                                const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe);
    // Main loop: per row-tile, load indices then scatter the matching rows.
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInIndices(int64_t progress);  // stage one tile of expandedRowIdx into UB
    __aicore__ inline void CopyOut(int64_t progress);        // scatter the corresponding input rows to GM

private:
    TPipe* pipe;
    // Double-buffered queue holding the input row (or column tile) being gathered.
    TQueBind<QuePosition::VECIN, QuePosition::VECOUT, BUFFER_NUM> inputActivationsCopyInQueue;
    // Queue holding the destination-row index tile.
    TQue<QuePosition::VECIN, BUFFER_NUM> expandDstToSrcRowCopyInQueue;

    GlobalTensor<T> inputXGm;
    GlobalTensor<T> expandedXGm;
    GlobalTensor<int32_t> expandedRowIdxGm;

    const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;

    // Tiling-derived scalars; see Init for their sources.
    int64_t needCoreNum;
    int64_t blockIdx;
    int64_t cols;
    int64_t n;
    int64_t k;
    int64_t activateRows;
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t rowLoops;
    int64_t colsTileLength;
    int64_t perLoopCols;
    int64_t lastLoopCols;
    int64_t colLoops;
    int64_t dropPadMode;

    // Per-iteration offsets into the index / input / output global tensors.
    int64_t indicesOffset;
    int64_t inputOffset;
    int64_t outOffset;
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherOut<T>::CopyInIndices(int64_t progress) {
|
||||
this->indicesOffset = progress * this->perLoopRows;
|
||||
LocalTensor<int32_t> indicesLocal = expandDstToSrcRowCopyInQueue.AllocTensor<int32_t>();
|
||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
|
||||
DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
|
||||
|
||||
expandDstToSrcRowCopyInQueue.EnQue<int32_t>(indicesLocal);
|
||||
}
|
||||
|
||||
// Scatter the input rows referenced by the current index tile into expandedX.
// Each source row (token) serves k consecutive expanded entries; dropped
// entries (index -1) and, in dropless mode, indices beyond activateRows are
// skipped.
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::CopyOut(int64_t progress) {
    LocalTensor<int32_t> indicesLocal = expandDstToSrcRowCopyInQueue.DeQue<int32_t>();
    // Scalar GetValue reads below must see the completed MTE2 index copy.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
    colsTileLength = this->perLoopCols;
    for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) {
        // First flattened (token, k) position handled by this iteration on this core.
        int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
        int64_t curLoopRow = 0;
        if (colsLoop == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;  // tail column tile may be narrower
        }
        // Source rows covered: every k expanded entries map back to one token row.
        int64_t currentLoopStartRow = initialRow / this->k;
        int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
        for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
            LocalTensor<T> inLocal = inputActivationsCopyInQueue.AllocTensor<T>();
            // input row position
            inputOffset = row * this->cols + colsLoop * this->perLoopCols;
            DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
            DataCopyPadExtParams<T> dataCopyPadParams{false, 0, 0, 0};
            DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams);
            // The MTE3 scatter below reads the buffer the MTE2 load writes.
            SetWaitFlag<HardEvent::MTE2_MTE3>(HardEvent::MTE2_MTE3);

            DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
            // Emit one output copy per expanded entry mapping to this source row.
            while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
                int32_t outIndex = indicesLocal.GetValue(curLoopRow);
                curLoopRow++;
                initialRow++;
                if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
                    continue;
                }
                outOffset = outIndex * cols + colsLoop * this->perLoopCols;
#ifdef __CCE_KT_TEST__
                // CPU-twin debugging cannot use multi-core synchronization, so the
                // index may be uninitialized garbage; guard out-of-range writes.
                if (outOffset > expandedXGm.GetSize()) {
                    continue;
                }
#endif
                DataCopyPad(expandedXGm[outOffset], inLocal, intriParams);
            }
            inputActivationsCopyInQueue.FreeTensor(inLocal);
        }
    }
    expandDstToSrcRowCopyInQueue.FreeTensor(indicesLocal);
}
|
||||
|
||||
// Bind global memory, pull this core's row split from the tiling, and set up
// the UB queues. Must be called once before Process().
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                               GM_ADDR workspace, const InnerMoeInitRoutingV2TilingData* tilingData,
                                               TPipe* tPipe) {
    this->pipe = tPipe;
    // Flatten (block, subblock) into a single logical core index.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);

    this->needCoreNum = this->gatherOutTilingData->needCoreNum;
    this->activateRows = this->gatherOutTilingData->activateRows;
    this->cols = tilingData->cols;
    this->n = tilingData->n;
    this->k = tilingData->k;
    this->dropPadMode = tilingData->dropPadMode;

    // The last participating core carries the remainder row split.
    if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
        this->coreRows = this->gatherOutTilingData->lastCoreRows;
        this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->gatherOutTilingData->perCoreRows;
        this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->perCoreLoops;
    }
    this->perLoopCols = this->gatherOutTilingData->perLoopCols;
    this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
    this->colLoops = this->gatherOutTilingData->colLoops;

    inputXGm.SetGlobalBuffer((__gm__ T*)inputX, this->coreRows * this->cols);
    expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, tilingData->n * tilingData->k * this->cols);
    // Each core reads only its own slice of the destination-row index array.
    expandedRowIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));

    pipe->InitBuffer(inputActivationsCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T)));
    pipe->InitBuffer(expandDstToSrcRowCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t)));
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherOut<T>::Process() {
|
||||
if (this->blockIdx < this->needCoreNum) {
|
||||
currentLoopRows = perLoopRows;
|
||||
for (int64_t loop = 0; loop < this->rowLoops; loop++) {
|
||||
if (loop == this->rowLoops - 1) {
|
||||
currentLoopRows = lastLoopRows;
|
||||
}
|
||||
CopyInIndices(loop);
|
||||
CopyOut(loop);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_GATHER_OUT_H
|
||||
@@ -0,0 +1,235 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_gather_quant.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_V2_GATHER_QUANT_H
|
||||
#define MOE_V2_GATHER_QUANT_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
constexpr int64_t BUFFER_NUM = 2;
|
||||
|
||||
// Gathers input rows into the expanded layout while statically quantizing them
// to int8 using a single scalar scale/offset pair read from GM.
template <typename T>
class MoeV2GatherQuant {
public:
    __aicore__ inline MoeV2GatherQuant(){};
    // Bind global buffers (including scalar scale/offset), read tiling, allocate queues.
    __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
    // Main loop: per row-tile, load indices, then load+quantize+scatter rows.
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyInIndices(int64_t progress);  // stage one tile of expandedRowIdx into UB
    __aicore__ inline void Compute();                        // quantize the staged column tile to int8
    __aicore__ inline void CopyOut(int64_t progress);        // gather + quantize + scatter

private:
    TPipe* pipe;
    TQue<QuePosition::VECIN, BUFFER_NUM> inputXCopyInQueue;       // raw input rows
    TQue<QuePosition::VECIN, BUFFER_NUM> expandRowIdxCopyInQueue; // destination-row indices
    TQue<QuePosition::VECOUT, BUFFER_NUM> inputXCopyOutQueue;     // quantized int8 rows
    TQue<QuePosition::VECOUT, 1> floatQueue;                      // fp32 scratch for casts
    TQue<QuePosition::VECOUT, 1> halfQueue;                       // fp16 scratch for casts

    GlobalTensor<T> inputXGm;
    GlobalTensor<int8_t> expandedXGm;
    GlobalTensor<int32_t> expandedRowIdxGm;
    GlobalTensor<float> scaleGm;
    GlobalTensor<float> offsetGm;

    const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;

    // Tiling-derived scalars; see Init for their sources.
    int64_t needCoreNum;
    int64_t blockIdx;
    int64_t cols;
    int64_t n;
    int64_t k;
    int64_t activateRows;
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t rowLoops;
    int64_t colsTileLength;
    int64_t perLoopCols;
    int64_t lastLoopCols;
    int64_t colLoops;
    int64_t dropPadMode;
    // Scalar quantization parameters, read once from GM in Init.
    float scale;
    float offset;

    // Per-iteration offsets into the index / input / output global tensors.
    int64_t indicesOffset;
    int64_t inputOffset;
    int64_t outOffset;
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherQuant<T>::CopyInIndices(int64_t progress) {
|
||||
this->indicesOffset = progress * this->perLoopRows;
|
||||
LocalTensor<int32_t> indicesLocal = expandRowIdxCopyInQueue.AllocTensor<int32_t>();
|
||||
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
|
||||
DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
|
||||
expandRowIdxCopyInQueue.EnQue<int32_t>(indicesLocal);
|
||||
}
|
||||
|
||||
// Statically quantize one staged column tile: out = round(x * scale + offset)
// as int8. The cast chain differs per input dtype; every vector op is fenced
// with pipe_barrier(PIPE_V) because each step consumes the previous result.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::Compute() {
    LocalTensor<T> inLocal = inputXCopyInQueue.DeQue<T>();
    LocalTensor<int8_t> outLocal = inputXCopyOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> floatLocal = floatQueue.AllocTensor<float>();
    LocalTensor<half> halfLocal = halfQueue.AllocTensor<half>();
    uint32_t elements = Align(this->colsTileLength, sizeof(T));
    if constexpr (IsSameType<T, bfloat16_t>::value) {
        // bf16 path: bf16 -> fp32 -> fp16, apply scale/offset in fp16, then
        // round-trip through int32 (with deq scale forced to 1.0) down to int8.
        Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        // Reuse the fp32 scratch buffer as the int32 intermediate.
        LocalTensor<int32_t> intLocal = floatLocal.ReinterpretCast<int32_t>();
        Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        SetDeqScale((half)1.000000e+00f);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else if constexpr (IsSameType<T, float>::value) {
        // fp32 path: drop to fp16, apply scale/offset, cast to int8.
        Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else {
        // fp16 path: apply scale/offset in place on the input tile.
        Muls(inLocal, inLocal, static_cast<T>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(inLocal, inLocal, static_cast<T>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements);
    }
    // inLocal is NOT freed here; CopyOut() still owns it and frees it after use.
    inputXCopyOutQueue.EnQue(outLocal);
    floatQueue.FreeTensor(floatLocal);
    halfQueue.FreeTensor(halfLocal);
}
|
||||
|
||||
// Gather, quantize (via Compute), and scatter the rows referenced by the
// current index tile. Dropped entries (-1) and, in dropless mode, indices
// beyond activateRows are skipped.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::CopyOut(int64_t progress) {
    LocalTensor<int32_t> indicesLocal = expandRowIdxCopyInQueue.DeQue<int32_t>();
    // Scalar GetValue reads below must see the completed MTE2 index copy.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
    colsTileLength = this->perLoopCols;
    for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) {
        // First flattened (token, k) position handled by this iteration on this core.
        int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
        int64_t curLoopRow = 0;
        if (colsLoop == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;  // tail column tile may be narrower
        }
        // Source rows covered: every k expanded entries map back to one token row.
        int64_t currentLoopStartRow = initialRow / this->k;
        int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
        for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
            LocalTensor<T> inLocal = inputXCopyInQueue.AllocTensor<T>();
            // input row position
            inputOffset = row * this->cols + colsLoop * this->perLoopCols;
            DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
            DataCopyPadExtParams<T> dataCopyPadParams{false, 0, 0, 0};
            DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams);
            // Hand the tile to Compute() via the queue; it returns int8 on
            // inputXCopyOutQueue (and leaves inLocal for us to free).
            inputXCopyInQueue.EnQue<T>(inLocal);
            Compute();
            LocalTensor<int8_t> outLocal = inputXCopyOutQueue.DeQue<int8_t>();
            DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(int8_t)), 0, 0, 0};
            // Emit one output copy per expanded entry mapping to this source row.
            while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
                int32_t outIndex = indicesLocal.GetValue(curLoopRow);
                curLoopRow++;
                initialRow++;
                if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
                    continue;
                }
                outOffset = outIndex * cols + colsLoop * this->perLoopCols;
                DataCopyPad(expandedXGm[outOffset], outLocal, intriParams);
            }
            inputXCopyInQueue.FreeTensor(inLocal);
            inputXCopyOutQueue.FreeTensor(outLocal);
        }
    }
    expandRowIdxCopyInQueue.FreeTensor(indicesLocal);
}
|
||||
|
||||
// Bind global memory (including the scalar scale/offset), pull this core's row
// split from the tiling, and set up the UB queues. Must precede Process().
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx,
                                                 GM_ADDR expandedX, GM_ADDR workspace,
                                                 const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe) {
    this->pipe = tPipe;
    // Flatten (block, subblock) into a single logical core index.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);

    this->needCoreNum = this->gatherOutTilingData->needCoreNum;
    this->activateRows = this->gatherOutTilingData->activateRows;
    this->cols = tilingData->cols;
    this->n = tilingData->n;
    this->k = tilingData->k;
    this->dropPadMode = tilingData->dropPadMode;

    // The last participating core carries the remainder row split.
    if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
        this->coreRows = this->gatherOutTilingData->lastCoreRows;
        this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->gatherOutTilingData->perCoreRows;
        this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
        this->rowLoops = this->gatherOutTilingData->perCoreLoops;
    }
    this->perLoopCols = this->gatherOutTilingData->perLoopCols;
    this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
    this->colLoops = this->gatherOutTilingData->colLoops;

    inputXGm.SetGlobalBuffer((__gm__ T*)inputX);
    expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);
    // Each core reads only its own slice of the destination-row index array.
    expandedRowIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));
    // Static quantization parameters are single scalars read once from GM.
    scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1);
    offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1);
    this->scale = scaleGm.GetValue(0);
    this->offset = offsetGm.GetValue(0);

    pipe->InitBuffer(inputXCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T)));
    pipe->InitBuffer(inputXCopyOutQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(int8_t)));
    pipe->InitBuffer(expandRowIdxCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t)));
    pipe->InitBuffer(floatQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
    pipe->InitBuffer(halfQueue, 1, AlignBytes(this->perLoopCols, sizeof(half)));
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2GatherQuant<T>::Process() {
|
||||
if (this->blockIdx < this->needCoreNum) {
|
||||
currentLoopRows = perLoopRows;
|
||||
for (int64_t loop = 0; loop < this->rowLoops; loop++) {
|
||||
if (loop == this->rowLoops - 1) {
|
||||
currentLoopRows = lastLoopRows;
|
||||
}
|
||||
CopyInIndices(loop);
|
||||
CopyOut(loop);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // MOE_V2_GATHER_QUANT_H
|
||||
@@ -0,0 +1,312 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/* !
|
||||
* \file moe_v2_init_routing_fullload.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_FULL_LOAD_H
|
||||
#define INNER_MOE_V2_FULL_LOAD_H
|
||||
|
||||
#include "moe_v2_mrgsort.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Full-load path of MoE init-routing: all (token, expert) pairs fit in UB at
// once, so routing sort, index output, expert-token statistics, and the final
// gather/scatter are done in a single pass per core.
template <typename T>
class MoeV2FullLoad : public MoeV2SortBase {
public:
    __aicore__ inline MoeV2FullLoad(){};
    // Bind global buffers, read tiling scalars, and allocate UB queues/buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn();        // load expert ids + build row-id payload
    __aicore__ inline void SortCompute();   // sort by expert; derive dst-row permutation
    __aicore__ inline void CopyOutIdx();    // write expandedRowIdx to GM (core 0 only)
    __aicore__ inline void CopyOutEmpty();  // drain the expert-idx queue on non-stats cores
    __aicore__ inline void CopyOutX();      // scatter this core's token rows to expandedX
    __aicore__ inline void ComputeExpertTokenCountOrCumsum();  // per-expert count/cumsum

private:
    int64_t sortNum_;  // totalLength padded up to a multiple of ONE_REPEAT_SORT_NUM
    const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_;
    int64_t blockIdx_;
    int64_t needCoreNum_;
    int64_t coreRows_;     // rows handled by this core (remainder on the last core)
    int64_t perCoreRows_;  // steady-state rows per core
    int64_t k_;
    int64_t n_;
    int64_t cols_;
    int64_t activateRows_;
    int64_t expertNum;
    int64_t expertCapacity;

    TQue<QuePosition::VECIN, 1> xCopyInQueue_;
    TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue_;
    TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue_;
    TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue_;
    TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue_;

    GlobalTensor<T> xGm_;
    GlobalTensor<int32_t> expertIdxGm_;

    GlobalTensor<T> expandedXGm_;
    GlobalTensor<int32_t> expandedRowIdxGm_;
    GlobalTensor<int32_t> expandedExpertIdxGm_;
    GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
    GlobalTensor<int32_t> expertTokensBeforeCapacityGm;

    // Output-mode flags from the tiling (0 = disabled).
    int64_t expertTokensCountOrCumsumFlag = 0;
    int64_t expertTokensBeforeCapacityFlag = 0;
    int64_t dropPadMode = 0;
};
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoad<T>::CopyIn() {
|
||||
LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
|
||||
DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
|
||||
0, 0, 0};
|
||||
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
|
||||
DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams);
|
||||
ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
|
||||
sortDataCopyInQueue.EnQue(inLocal);
|
||||
}
|
||||
|
||||
// Two-stage routing sort.
// Stage 1 sorts (expertIdx, rowIdx) pairs by expert id: the hardware Sort is
// descending, so keys are negated first, and positions past totalLength are
// padded with MIN_FP32 so they sink to the end. Extract() yields the sorted
// expert ids (expandedExpertIdx) and the dst->src row mapping.
// Stage 2 inverts that permutation: sorting the negated dst->src mapping with
// fresh 0..N-1 payload gives, for each original (token, k) entry, its
// destination row in the expanded layout (expandedRowIdx).
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::SortCompute() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertIdxLocal = inLocal[0];
    // Keys are converted to fp32 in place (same buffer reinterpreted).
    LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
    Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    // Negate so the descending hardware sort yields ascending expert ids.
    Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Pad the tail of the last repeat with MIN_FP32 so padding sorts last.
    int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    LocalTensor<float> concatLocal;
    LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast<uint32_t>();
    LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor<float>();
    LocalTensor<uint32_t> expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor<uint32_t>();
    LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
    // Split sorted (key, payload) pairs: keys = negated expert ids, payload = source rows.
    Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
         this->totalLength);
    pipe_barrier(PIPE_V);
    // Undo the key negation to recover the real expert ids.
    Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> expandedExpertIdxLocalInt32;
    expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
    Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    expandedExpertIdxCopyOutQueue_.EnQue<int32_t>(expandedExpertIdxLocalInt32);

    // ---- Stage 2: invert the permutation ----
    LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor<uint32_t>();
    LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
    Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Regenerate the 0..N-1 payload (the sort above consumed the original).
    ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
    pipe_barrier(PIPE_V);
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    // Keys are discarded into tempTensor; the payload is the inverse mapping.
    Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    expandedRowIdxCopyOutQueue_.EnQue<uint32_t>(expandedRowIdx);
    sortDataCopyInQueue.FreeTensor(inLocal);

    expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
}
|
||||
|
||||
// Write the full expandedRowIdx array to GM. Called on core 0 only (every core
// computed the same full mapping in SortCompute).
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyOutIdx() {
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    DataCopyParams intriParams;
    intriParams.blockCount = 1;
    intriParams.blockLen = this->totalLength * sizeof(int32_t);
    DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams);
    // Re-enqueue instead of freeing: CopyOutX() dequeues this tensor again later.
    expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoad<T>::ComputeExpertTokenCountOrCumsum() {
|
||||
LocalTensor<int32_t> expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
|
||||
LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue_.AllocTensor<int32_t>();
|
||||
|
||||
int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
|
||||
Duplicate(expertTokensCount, 0, expertNumAlign);
|
||||
SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
|
||||
|
||||
int32_t lastExpertId = expandedExpertIdx.GetValue(0);
|
||||
int64_t tokenCount = 0;
|
||||
int64_t lastExpertCount = 0;
|
||||
for (int64_t i = 0; i < this->totalLength; i++) {
|
||||
int32_t curExpertId = expandedExpertIdx.GetValue(i);
|
||||
tokenCount++;
|
||||
while (lastExpertId < curExpertId) {
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
|
||||
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
|
||||
tokenCount = 1;
|
||||
}
|
||||
lastExpertId++;
|
||||
}
|
||||
}
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount);
|
||||
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
|
||||
lastExpertId++;
|
||||
while (lastExpertId < this->expertNum) {
|
||||
expertTokensCount.SetValue(lastExpertId, tokenCount);
|
||||
lastExpertId++;
|
||||
}
|
||||
}
|
||||
DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
|
||||
0};
|
||||
if (this->expertTokensCountOrCumsumFlag > 0) {
|
||||
DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
|
||||
}
|
||||
expertTokensCopyOutQueue_.FreeTensor(expertTokensCount);
|
||||
expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx);
|
||||
}
|
||||
|
||||
// Load this core's span of source token rows once, then scatter each row to
// its destination(s) in the expanded output using the expandedRowIdx mapping
// (re-enqueued by CopyOutIdx on core 0, still enqueued from SortCompute on
// other cores). Only destinations below activateRows are written.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyOutX() {
    LocalTensor<T> xLocal = xCopyInQueue_.AllocTensor<T>();
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    DataCopyParams intriParams;
    intriParams.blockCount = 1;
    intriParams.blockLen = this->cols_ * sizeof(T);
    // Padded per-row stride of the rows staged in UB.
    int64_t inFactor = Align(this->cols_, sizeof(T));
    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    // Source token rows touched by this core's expanded-entry range.
    int64_t startXRow = curRowsStart / this->k_;
    int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_;

    DataCopyExtParams dataXCopyParams{static_cast<uint16_t>(endXRow - startXRow + 1),
                                      static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
    DataCopyPadExtParams<T> dataXCopyPadParams{false, 0, 0, 0};
    DataCopyPad(xLocal, xGm_[startXRow * this->cols_], dataXCopyParams, dataXCopyPadParams);
    // Scalar GetValue reads of expandedRowIdx must see the completed MTE2 copy.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);

    int64_t k = 0;
    for (int64_t i = startXRow; i <= endXRow; i++) {
        // Each source row i serves every expanded entry mapping back to it
        // (k_ entries per token).
        // NOTE(review): the bound uses perCoreRows_ rather than coreRows_; on
        // the last core (coreRows_ < perCoreRows_) the outer endXRow bound
        // appears to keep this in range — confirm against the tiling.
        for (; k < this->perCoreRows_ && curRowsStart / this->k_ == i; curRowsStart++, k++) {
            int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
            if (outIndex < this->activateRows_) {
                DataCopyPad(expandedXGm_[outIndex * this->cols_], xLocal[(i - startXRow) * inFactor], intriParams);
            }
        }
    }
    expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
    xCopyInQueue_.FreeTensor(xLocal);
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void MoeV2FullLoad<T>::CopyOutEmpty() {
|
||||
LocalTensor<int32_t> outLocal = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
|
||||
expandedExpertIdxCopyOutQueue_.FreeTensor(outLocal);
|
||||
}
|
||||
|
||||
// Bind global memory, pull this core's split from the tiling, and allocate the
// UB queues/buffers for the full-load routing path. Must precede Process().
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                              GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                              const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe) {
    this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp);
    // Flatten (block, subblock) into a single logical core index.
    this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num();
    this->k_ = tilingData->k;
    this->n_ = tilingData->n;
    this->cols_ = tilingData->cols;
    this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
    this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
    this->activateRows_ = this->gatherOutTilingData_->activateRows;
    // The last participating core carries the remainder row split.
    if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) {
        this->coreRows_ = this->gatherOutTilingData_->lastCoreRows;
    } else {
        this->coreRows_ = this->gatherOutTilingData_->perCoreRows;
    }
    this->expertNum = tilingData->expertNum;
    this->dropPadMode = tilingData->dropPadMode;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;

    // Full-load: the whole (n * k) index array is staged at once; sortNum_ is
    // that length padded to a multiple of the hardware sort repeat size.
    this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
    this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    this->totalLength = tilingData->n * tilingData->k;
    this->pipe = tPipe;

    xGm_.SetGlobalBuffer((__gm__ T*)x);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);

    expandedXGm_.SetGlobalBuffer((__gm__ T*)expandedX);
    expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
    if (this->expertTokensCountOrCumsumFlag > 0) {
        // dropless
        expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                    Align(this->expertNum, sizeof(int32_t)));
    }

    // Sort buffers hold (key, payload) pairs, hence the factor of 2.
    int64_t kvFactor = 2;
    int64_t buffSize = this->sortNum_ * sizeof(int32_t);

    // Size the x staging buffer to exactly the token rows this core touches.
    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    int64_t startXRow = curRowsStart / this->k_;
    int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_;
    pipe->InitBuffer(xCopyInQueue_, bufferNum, AlignBytes(this->cols_, sizeof(T)) * (endXRow - startXRow + 1));
    pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize);
    pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize);
    pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
    pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize);
    pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
    pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
    pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
}
|
||||
|
||||
template <typename T>
// Kernel entry for the full-load path: sort the routing indices, then emit
// index outputs (core 0), expert token counts (last core), and gathered rows.
__aicore__ inline void MoeV2FullLoad<T>::Process() {
  if (this->blockIdx_ < this->needCoreNum_) {
    CopyIn();
    SortCompute();
    // Only core 0 writes the index outputs.
    if (this->blockIdx_ == 0) {
      CopyOutIdx();
    }
    // The last core computes per-expert token counts/cumsum when requested;
    // all other cores drain their queue slot so the pipeline stays balanced.
    if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
      ComputeExpertTokenCountOrCumsum();
    } else {
      CopyOutEmpty();
    }
    CopyOutX();
  }
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_FULL_LOAD_H
|
||||
@@ -0,0 +1,189 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_mrgsort.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_MRGSORT_H
|
||||
#define INNER_MOE_V2_MRGSORT_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Geometry of a k-way merge: element counts per input list and the UB chunk
// size processed per loop iteration.
struct MoeV2MrgsortParam {
  int64_t perListElements;    // elements in each list except the last
  int64_t lastListElements;   // elements in the final (possibly shorter) list
  int64_t oneLoopMaxElements; // max elements loaded per list per loop
};
|
||||
|
||||
// Merges up to four pre-sorted float sort-record lists (key + index pairs)
// from GM into one sorted GM output, streaming UB-sized chunks through the
// hardware MrgSort instruction. Register sources with SetInput(), the sink
// with SetOutput(), then call Init() and Process().
class MoeV2Mrgsort {
 public:
  __aicore__ inline MoeV2Mrgsort(){};
  __aicore__ inline void Init(MoeV2MrgsortParam* param);
  __aicore__ inline void Process();
  __aicore__ inline void SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput);
  __aicore__ inline void SetOutput(GlobalTensor<float>& gmOutput, LocalTensor<float>& ubOutput);

 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void UpdateMrgParam();
  __aicore__ inline void MrgsortCompute();
  __aicore__ inline void UpdateSortInfo();
  __aicore__ inline void CopyOut();
  __aicore__ inline void ClearCache();

 private:
  MoeV2MrgsortParam* param = nullptr;

  GlobalTensor<float> gmInputs[4];  // one GM source per merge list
  GlobalTensor<float> gmOutput;

  LocalTensor<float> ubInputs[4];   // UB staging buffer per list
  LocalTensor<float> ubOutput;

  int64_t listNum{0};               // number of registered input lists
  int64_t remainListNum{0};         // lists with data remaining this loop
  int64_t outOffset{0};             // write cursor into gmOutput
  int64_t offsets[4];               // per-list read cursor (sort-record units)
  int64_t listRemainElements[4];    // unread elements left per list
  int64_t lengths[4];               // elements loaded per list this loop
  int64_t allRemainElements{0};     // total unread elements across lists
  int64_t curLoopSortedNum{0};      // elements merged in the current loop

  // for MrgSort
  uint16_t validBitTail{0};         // bitmask of valid source operands
  uint16_t elementCountListTail[4]; // element count per valid operand
  uint32_t listSortedNums[4];       // per-list consumed counts reported by MrgSort
  LocalTensor<float> tmpUbInputs[4]; // non-empty lists compacted for MrgSort
};
|
||||
|
||||
// Reset the merge bookkeeping so the sorter can be reused for another round.
__aicore__ inline void MoeV2Mrgsort::ClearCache() {
  listNum = 0;
  allRemainElements = 0;
  outOffset = 0;
}
|
||||
|
||||
// Register one sorted source list: its GM tensor and the UB buffer used to
// stage its chunks. Slots are assigned in registration order.
__aicore__ inline void MoeV2Mrgsort::SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput) {
  const int64_t slot = listNum;
  gmInputs[slot] = gmInput;
  ubInputs[slot] = ubInput;
  listNum = slot + 1;
}
|
||||
|
||||
// Bind the merged-output destination: GM sink plus the UB buffer MrgSort
// writes each merged chunk into before copy-out.
__aicore__ inline void MoeV2Mrgsort::SetOutput(GlobalTensor<float>& gmOutput, LocalTensor<float>& ubOutput) {
  this->gmOutput = gmOutput;
  this->ubOutput = ubOutput;
}
|
||||
|
||||
// Derive the MrgSort operand mask from how many lists still hold data this
// loop, zeroing element counts of the unused trailing operands.
__aicore__ inline void MoeV2Mrgsort::UpdateMrgParam() {
  switch (this->remainListNum) {
    case MERGE_LIST_TWO:
      elementCountListTail[MERGE_LIST_IDX_TWO] = 0;
      elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
      validBitTail = 0b0011;
      break;
    case MERGE_LIST_THREE:
      elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
      validBitTail = 0b0111;
      break;
    case MERGE_LIST_FOUR:
      validBitTail = 0b1111;
      break;
    default:
      // Single remaining list: only operand 0 is valid.
      validBitTail = 0b0001;
      break;
  }
}
|
||||
|
||||
// Load the next chunk of each non-exhausted list from GM into its UB buffer
// and compact the non-empty lists into tmpUbInputs[0..remainListNum).
__aicore__ inline void MoeV2Mrgsort::CopyIn() {
  this->remainListNum = 0;
  // Wait for the previous copy-out before overwriting the UB input buffers.
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]);
    if (lengths[i] > 0) {
      // Sort records are wider than raw elements; GetSortLen converts counts.
      DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen<float>(lengths[i]), sizeof(float)));
      tmpUbInputs[j] = this->ubInputs[i];
      elementCountListTail[j] = lengths[i];
      this->remainListNum += 1;
      j++;
    }
  }
}
|
||||
|
||||
// Merge the loaded chunks with the vector MrgSort instruction; when only one
// list remains its chunk is passed through unchanged.
__aicore__ inline void MoeV2Mrgsort::MrgsortCompute() {
  // Ensure the MTE2 loads have landed before the vector unit reads them.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  if (this->remainListNum == MERGE_LIST_TWO) {
    // Unused operand slots are filled with tmpUbInputs[0]; validBitTail masks them out.
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    MrgSortSrcList sortListTail =
        MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_FOUR) {
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO],
                                                 tmpUbInputs[MERGE_LIST_IDX_THREE]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else {
    // Single list: copy through and report everything as consumed.
    DataCopy(this->ubOutput, this->tmpUbInputs[0], Align(GetSortLen<float>(elementCountListTail[0]), sizeof(float)));
    listSortedNums[0] = elementCountListTail[0];
  }
}
|
||||
|
||||
// Account for how many elements MrgSort consumed from each list this loop:
// shrink remaining counts, advance read offsets, and total the merged count.
// j indexes the compacted operand order used in CopyIn.
__aicore__ inline void MoeV2Mrgsort::UpdateSortInfo() {
  curLoopSortedNum = 0;
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    if (lengths[i] > 0) {
      // update remain size
      listRemainElements[i] -= listSortedNums[j];
      allRemainElements -= listSortedNums[j];
      // update offset
      offsets[i] += GetSortOffset<float>(listSortedNums[j]);
      // update current loop sorted nums
      curLoopSortedNum += listSortedNums[j];
      j += 1;
    }
  }
}
|
||||
|
||||
// Flush the merged chunk (still in packed sort-record form) to GM and
// advance the output cursor.
__aicore__ inline void MoeV2Mrgsort::CopyOut() {
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  // blockLen is in bytes for DataCopyPad.
  intriParams.blockLen = GetSortLen<float>(curLoopSortedNum) * sizeof(float);
  // Wait for the vector merge to finish before MTE3 reads ubOutput.
  SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
  DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams);
  outOffset += GetSortLen<float>(curLoopSortedNum);
}
|
||||
|
||||
// Initialize per-list read offsets and remaining counts from the merge
// geometry; the final list may be shorter than the rest. Must be called
// after all SetInput() registrations.
__aicore__ inline void MoeV2Mrgsort::Init(MoeV2MrgsortParam* param) {
  this->param = param;
  this->remainListNum = listNum;

  for (int64_t i = 0; i < listNum; i++) {
    // Lists are laid out back-to-back in sort-record units.
    offsets[i] = GetSortOffset<float>(param->perListElements * i);
    if (i == listNum - 1) {
      listRemainElements[i] = param->lastListElements;
    } else {
      listRemainElements[i] = param->perListElements;
    }
    allRemainElements += listRemainElements[i];
  }
}
|
||||
|
||||
// Drain all input lists: each pass stages one UB-sized chunk per list,
// merges them, and flushes the result to GM; state is reset afterwards so
// the object can be reused.
__aicore__ inline void MoeV2Mrgsort::Process() {
  while (allRemainElements > 0) {
    CopyIn();
    UpdateMrgParam();
    MrgsortCompute();
    UpdateSortInfo();
    CopyOut();
  }

  ClearCache();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_MRGSORT_H
|
||||
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_mrgsort_out.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_MRGSORT_OUT_H
|
||||
#define INNER_MOE_V2_MRGSORT_OUT_H
|
||||
|
||||
#include "moe_v2_mrgsort.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
class MoeV2MrgsortOut {
|
||||
public:
|
||||
__aicore__ inline MoeV2MrgsortOut(){};
|
||||
__aicore__ inline void Init(MoeV2MrgsortParam* param, TPipe* tPipe);
|
||||
__aicore__ inline void Process();
|
||||
__aicore__ inline void SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput);
|
||||
__aicore__ inline void SetOutput(GlobalTensor<int32_t>& gmOutput1, GlobalTensor<int32_t>& gmOutput2,
|
||||
LocalTensor<float>& ubOutput1, LocalTensor<float>& ubOutput2);
|
||||
__aicore__ inline void SetBuffer(LocalTensor<float>& tempBuffer);
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyIn();
|
||||
__aicore__ inline void UpdateMrgParam();
|
||||
__aicore__ inline void MrgsortCompute();
|
||||
__aicore__ inline void UpdateSortInfo();
|
||||
__aicore__ inline void Extract();
|
||||
__aicore__ inline void CopyOut();
|
||||
__aicore__ inline void ClearCache();
|
||||
|
||||
private:
|
||||
MoeV2MrgsortParam* param = nullptr;
|
||||
|
||||
GlobalTensor<float> gmInputs[4];
|
||||
GlobalTensor<int32_t> gmOutput1;
|
||||
GlobalTensor<int32_t> gmOutput2;
|
||||
|
||||
LocalTensor<float> ubInputs[4];
|
||||
LocalTensor<float> tempBuffer;
|
||||
|
||||
// for extract
|
||||
LocalTensor<float> ubOutput1;
|
||||
LocalTensor<uint32_t> ubOutput2;
|
||||
|
||||
// for copy out
|
||||
LocalTensor<int32_t> ubOutputInt1;
|
||||
LocalTensor<int32_t> ubOutputInt2;
|
||||
|
||||
int64_t listNum{0};
|
||||
int64_t remainListNum{0};
|
||||
int64_t outOffset{0};
|
||||
int64_t offsets[4];
|
||||
int64_t listRemainElements[4];
|
||||
int64_t lengths[4];
|
||||
int64_t allRemainElements{0};
|
||||
int64_t curLoopSortedNum{0};
|
||||
|
||||
// for MrgSort
|
||||
uint16_t validBitTail;
|
||||
uint16_t elementCountListTail[4];
|
||||
uint32_t listSortedNums[4];
|
||||
LocalTensor<float> tmpUbInputs[4];
|
||||
};
|
||||
|
||||
// Reset the merge bookkeeping so the sorter can be reused for another round.
__aicore__ inline void MoeV2MrgsortOut::ClearCache() {
  listNum = 0;
  allRemainElements = 0;
  outOffset = 0;
}
|
||||
|
||||
// Register one sorted source list: its GM tensor and the UB buffer used to
// stage its chunks. Slots are assigned in registration order.
__aicore__ inline void MoeV2MrgsortOut::SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput) {
  const int64_t slot = listNum;
  gmInputs[slot] = gmInput;
  ubInputs[slot] = ubInput;
  listNum = slot + 1;
}
|
||||
|
||||
// Bind the two int32 GM outputs and the UB buffers used by Extract/CopyOut.
// Each UB buffer is aliased via ReinterpretCast: the same storage is viewed
// as float/uint32 for Extract and as int32 for the final copy-out.
__aicore__ inline void MoeV2MrgsortOut::SetOutput(GlobalTensor<int32_t>& gmOutput1, GlobalTensor<int32_t>& gmOutput2,
                                                  LocalTensor<float>& ubOutput1, LocalTensor<float>& ubOutput2) {
  this->gmOutput1 = gmOutput1;
  this->ubOutput1 = ubOutput1;
  this->ubOutputInt1 = ubOutput1.ReinterpretCast<int32_t>();

  this->gmOutput2 = gmOutput2;
  this->ubOutput2 = ubOutput2.ReinterpretCast<uint32_t>();
  this->ubOutputInt2 = ubOutput2.ReinterpretCast<int32_t>();
}
|
||||
|
||||
// Bind the UB scratch buffer that MrgsortCompute writes merged records into.
__aicore__ inline void MoeV2MrgsortOut::SetBuffer(LocalTensor<float>& tempBuffer) {
  this->tempBuffer = tempBuffer;
}
|
||||
|
||||
// Derive the MrgSort operand mask from how many lists still hold data this
// loop, zeroing element counts of the unused trailing operands.
__aicore__ inline void MoeV2MrgsortOut::UpdateMrgParam() {
  switch (this->remainListNum) {
    case MERGE_LIST_TWO:
      elementCountListTail[MERGE_LIST_IDX_TWO] = 0;
      elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
      validBitTail = 0b0011;
      break;
    case MERGE_LIST_THREE:
      elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
      validBitTail = 0b0111;
      break;
    case MERGE_LIST_FOUR:
      validBitTail = 0b1111;
      break;
    default:
      // Single remaining list: only operand 0 is valid.
      validBitTail = 0b0001;
      break;
  }
}
|
||||
|
||||
// Load the next chunk of each non-exhausted list from GM into its UB buffer
// and compact the non-empty lists into tmpUbInputs[0..remainListNum).
__aicore__ inline void MoeV2MrgsortOut::CopyIn() {
  this->remainListNum = 0;
  // Wait for the previous copy-out before overwriting the UB input buffers.
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]);
    if (lengths[i] > 0) {
      // Sort records are wider than raw elements; GetSortLen converts counts.
      DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen<float>(lengths[i]), sizeof(float)));
      tmpUbInputs[j] = this->ubInputs[i];
      elementCountListTail[j] = lengths[i];
      this->remainListNum += 1;
      j++;
    }
  }
}
|
||||
|
||||
// Merge the loaded chunks into tempBuffer with the vector MrgSort
// instruction; when only one list remains its chunk is passed through.
__aicore__ inline void MoeV2MrgsortOut::MrgsortCompute() {
  // Ensure the MTE2 loads have landed before the vector unit reads them.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  if (this->remainListNum == MERGE_LIST_TWO) {
    // Unused operand slots are filled with tmpUbInputs[0]; validBitTail masks them out.
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    MrgSortSrcList sortListTail =
        MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_FOUR) {
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO],
                                                 tmpUbInputs[MERGE_LIST_IDX_THREE]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else {
    // Single list: copy through and report everything as consumed.
    DataCopy(this->tempBuffer, this->tmpUbInputs[0], Align(GetSortLen<float>(elementCountListTail[0]), sizeof(float)));
    listSortedNums[0] = elementCountListTail[0];
  }
}
|
||||
|
||||
// Account for how many elements MrgSort consumed from each list this loop:
// shrink remaining counts, advance read offsets, and total the merged count.
// j indexes the compacted operand order used in CopyIn.
__aicore__ inline void MoeV2MrgsortOut::UpdateSortInfo() {
  curLoopSortedNum = 0;
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    if (lengths[i] > 0) {
      // update remain size
      listRemainElements[i] -= listSortedNums[j];
      allRemainElements -= listSortedNums[j];
      // update offset
      offsets[i] += GetSortOffset<float>(listSortedNums[j]);
      // update current loop sorted nums
      curLoopSortedNum += listSortedNums[j];
      j += 1;
    }
  }
}
|
||||
|
||||
// Unpack the merged sort records into keys (ubOutput1) and indices
// (ubOutput2), then negate the keys and round them back to int32 — this
// undoes the "*-1" applied to the keys before sorting (see UBSortCompute in
// the sort stage), restoring ascending expert ids.
__aicore__ inline void MoeV2MrgsortOut::Extract() {
  AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM));
  pipe_barrier(PIPE_V);
  Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float)));
  pipe_barrier(PIPE_V);
  // ubOutputInt1 aliases ubOutput1, so the cast is done in place.
  Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float)));
}
|
||||
|
||||
// Write the extracted keys and indices to their GM outputs and advance the
// shared element cursor.
__aicore__ inline void MoeV2MrgsortOut::CopyOut() {
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  // blockLen is in bytes for DataCopyPad; outputs are plain int32 elements.
  intriParams.blockLen = curLoopSortedNum * sizeof(int32_t);
  // Wait for Extract's vector work before MTE3 reads the UB buffers.
  SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
  DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams);
  DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams);
  outOffset += curLoopSortedNum;
}
|
||||
|
||||
// Initialize per-list read offsets and remaining counts from the merge
// geometry; the final list may be shorter than the rest. Must be called
// after all SetInput() registrations.
__aicore__ inline void MoeV2MrgsortOut::Init(MoeV2MrgsortParam* param, TPipe* tPipe) {
  this->param = param;
  this->allRemainElements = 0;
  for (int64_t i = 0; i < listNum; i++) {
    // Lists are laid out back-to-back in sort-record units.
    offsets[i] = GetSortOffset<float>(param->perListElements * i);
    if (i == listNum - 1) {
      listRemainElements[i] = param->lastListElements;
    } else {
      listRemainElements[i] = param->perListElements;
    }
    allRemainElements += listRemainElements[i];
  }
}
|
||||
|
||||
// Drain all input lists: each pass merges one UB-sized chunk, unpacks it
// into ids/indices, and writes both GM outputs; state is reset afterwards
// so the object can be reused.
__aicore__ inline void MoeV2MrgsortOut::Process() {
  while (allRemainElements > 0) {
    CopyIn();
    UpdateMrgParam();
    MrgsortCompute();
    UpdateSortInfo();
    Extract();
    CopyOut();
  }
  ClearCache();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_MRGSORT_OUT_H
|
||||
@@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_sort_base.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_SORT_BASE_H
|
||||
#define INNER_MOE_V2_SORT_BASE_H
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Shared state for the MoeInitRoutingQuantV2 sort stages: UB queues and
// scratch buffers, GM views of the expert-index streams, and tiling scalars.
class MoeV2SortBase {
 public:
  __aicore__ inline MoeV2SortBase(){};

 protected:
  __aicore__ inline void SyncAll();

 protected:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> sortDataCopyInQueue;
  TQue<QuePosition::VECOUT, 1> sortDataCopyOutQueue;
  TBuf<TPosition::VECCALC> tempBuffer;
  TBuf<TPosition::VECCALC> sortedBuffer;

  GlobalTensor<int32_t> expertIdxGm;             // input expert ids
  GlobalTensor<int32_t> sortedexpertIdxGm;       // sorted expert ids output
  GlobalTensor<int32_t> expandDstToSrcRowGm;     // dst->src row mapping output
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;

  int64_t tileLength;
  int64_t bufferNum = 1;  // single-buffered queues
  int64_t totalLength;
  int64_t coreNum;
  int64_t n;
  int64_t k;
  int64_t existRowIdx;
  int64_t expertNum;
  int64_t expertTokensCountOrCumsumFlag = 0;
  int64_t expertTokensBeforeCapacityFlag = 0;

  static constexpr int64_t SYNC_GM_NUM = 2;
  static constexpr int64_t WORK_GM_NUM = 2;      // ping-pong workspace count
  static constexpr int64_t DST_BLK_STRIDE = 1;
  static constexpr int64_t DST_REP_STRIDE = 8;
};
|
||||
|
||||
// Cross-core barrier; a single-core launch needs no synchronization, and
// the CPU kernel-test build (__CCE_KT_TEST__) compiles the barrier out.
__aicore__ inline void MoeV2SortBase::SyncAll() {
  if (coreNum != 1) {
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
  }
}
|
||||
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_SORT_BASE_H
|
||||
@@ -0,0 +1,373 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_sort_multi_core.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_VBS_ONE_CORE_H
|
||||
#define INNER_MOE_V2_VBS_ONE_CORE_H
|
||||
|
||||
#include "moe_v2_sort_base.h"
|
||||
#include "moe_v2_mrgsort.h"
|
||||
#include "moe_v2_mrgsort_out.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Multi-core sort pipeline over the expert-index stream:
//   VBS  - each core sorts its UB-sized tiles locally,
//   VMS  - sorted runs are merged 4-at-a-time across cores,
//   SortOut - the final merge extracts sorted ids and the row mapping.
class MoeV2SortMultiCore : public MoeV2SortBase {
 public:
  __aicore__ inline MoeV2SortMultiCore(){};
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                              GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();

 private:
  __aicore__ inline void VBSProcess();
  __aicore__ inline void UBSortProcess(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void OneCoreVMSProcess(int64_t listNum, int64_t perListElements, int64_t lastListElements);
  __aicore__ inline void VMSProcess();
  __aicore__ inline void SortOutProcess();
  __aicore__ inline void VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void UBSortCompute(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset, int64_t loopOffset);
  __aicore__ inline void InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum, int64_t coreOffset);
  __aicore__ inline void InitExpertTokensGlobalMemory();

 private:
  GlobalTensor<float> workspaceGms[2];  // ping-pong sort workspaces in GM

  const InnerMoeV2VBSComputeTilingData* vbsTilingData;
  const InnerMoeV2VMSMiddleComputeTilingData* vmsTilingData;
  const InnerMoeV2SortOutComputeTilingData* sortOutTilingData;

  // for MoeMrgsort
  MoeV2Mrgsort mrgsorter;
  MoeV2MrgsortParam mrgsortParam;

  // NOTE(review): shadows MoeV2SortBase::coreNum — confirm this is intended.
  int64_t coreNum;
  int64_t blockIdx;
  int64_t srcWsIndex = 0;  // which workspaceGms entry holds the current source runs

  int64_t listNum;         // sorted runs remaining to merge
  int64_t perListElements;
  int64_t lastListElements;

  int64_t sortTotalLength;
  int64_t sortCoreLoops;
  int64_t sortCoreLoopElements;
  int64_t sortCoreLastLoopElements;

  int64_t perCoreExpert;
  int64_t needInitExpertCore;
  int64_t currentCoreExpert;

  static constexpr int64_t MAX_MRGSORT_LIST = 4;  // MrgSort merges up to 4 lists
};
|
||||
|
||||
// Zero-fill this core's slice of the expert-token output buffers, when the
// corresponding outputs were requested. Only the first needInitExpertCore
// cores participate.
__aicore__ inline void MoeV2SortMultiCore::InitExpertTokensGlobalMemory() {
  if (this->blockIdx >= this->needInitExpertCore) {
    return;
  }
  if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
    InitGlobalMemory(expertTokensCountOrCumsumGm, currentCoreExpert, 0);
  }
  if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
    InitGlobalMemory(expertTokensBeforeCapacityGm, currentCoreExpert, 0);
  }
}
|
||||
|
||||
// Load this loop's slice of expert ids into the first half of the input
// tensor, and synthesize the matching global row indices (an arithmetic
// progression) into the second half starting at element sortNum.
__aicore__ inline void MoeV2SortMultiCore::VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
  int64_t inOffset = progress * sortCoreLoopElements;
  DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(size * sizeof(int32_t)), 0, 0, 0};
  DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
  DataCopyPad(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams);

  LocalTensor<int32_t> rowIdxLocal = inLocal[sortNum];
  // Global row index of the first element this core/loop handles.
  int64_t startValue = this->blockIdx * this->vbsTilingData->perCoreElements + inOffset;
  SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
  ArithProgression<int32_t>(rowIdxLocal, startValue, 1, size);
  sortDataCopyInQueue.EnQue(inLocal);
}
|
||||
|
||||
// Sort one UB tile of (expert id, row index) pairs. Keys are cast to float
// and negated so the descending hardware Sort yields ascending expert ids;
// padding lanes are filled with MIN_FP32 so they sink to the end.
__aicore__ inline void MoeV2SortMultiCore::UBSortCompute(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> expertForSourceRowLocal = inLocal[0];
  LocalTensor<float> expertForSourceRowLocalFp32;

  // In-place reinterpret: the int32 keys become float keys after the Cast.
  expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast<float>();
  Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, sortNum);
  pipe_barrier(PIPE_V);
  Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, sortNum);
  pipe_barrier(PIPE_V);

  // Fill the tail lanes of the last partial repeat (size..sortNum) with
  // MIN_FP32 via a masked Duplicate so padding never outranks real keys.
  int64_t duplicateNum = size % ONE_REPEAT_SORT_NUM;
  if (duplicateNum > 0) {
    int duplicateIndex = size - duplicateNum;
    uint64_t mask0 = UINT64_MAX;
    mask0 = mask0 << duplicateNum;
    mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
    uint64_t mask[2] = {mask0, 0};
    Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
    pipe_barrier(PIPE_V);
  }

  LocalTensor<float> concatLocal = expertForSourceRowLocalFp32;
  LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(sortNum));
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  LocalTensor<uint32_t> sourceRowLocal;
  // Row indices were staged at offset sortNum by VBSCopyIn.
  sourceRowLocal = inLocal[sortNum].ReinterpretCast<uint32_t>();
  Sort<float, true>(outLocal, concatLocal, sourceRowLocal, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM);

  sortDataCopyOutQueue.EnQue<float>(outLocal);
  sortDataCopyInQueue.FreeTensor(inLocal);
}
|
||||
|
||||
// Spill this loop's sorted run (packed sort records) to this core's region
// of the first ping-pong workspace.
__aicore__ inline void MoeV2SortMultiCore::VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<float> outLocal = sortDataCopyOutQueue.DeQue<float>();
  DataCopy(workspaceGms[0][this->blockIdx * GetSortLen<float>(this->vbsTilingData->perCoreElements) +
                           GetSortLen<float>(progress * sortCoreLoopElements)],
           outLocal, Align(GetSortLen<float>(size), sizeof(float)));
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Wire a MoeV2Mrgsort instance: bind listNum views of the current source
// workspace as inputs and the opposite ping-pong workspace as output.
// Tensors are allocated and freed immediately — apparently only to carve UB
// addresses for the sorter; TODO(review) confirm this queue usage pattern.
__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset,
                                                          int64_t loopOffset) {
  GlobalTensor<float> srcWsGm = workspaceGms[srcWsIndex][blockIdx * coreOffset + loopOffset];
  LocalTensor<float> inLocal = sortDataCopyInQueue.AllocTensor<float>();
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  for (int64_t i = 0; i < listNum; i++) {
    // Each list gets its own oneLoopMaxElements-sized slice of inLocal.
    LocalTensor<float> inLocalT = inLocal[GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * i];
    sorter->SetInput(srcWsGm, inLocalT);
  }
  GlobalTensor<float> dstWsGm = workspaceGms[1 - srcWsIndex][blockIdx * coreOffset + loopOffset];
  sorter->SetOutput(dstWsGm, outLocal);
  sortDataCopyInQueue.FreeTensor(inLocal);
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Wire the final MoeV2MrgsortOut stage: inputs come from the current source
// workspace; outputs are the sorted expert ids and dst->src row mapping GM
// tensors. outLocal is split into key/index halves for Extract; sortedBuffer
// provides the merge scratch. Tensors are allocated and freed immediately —
// apparently only to carve UB addresses; TODO(review) confirm.
__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum,
                                                             int64_t coreOffset) {
  GlobalTensor<float> srcWsGm = workspaceGms[srcWsIndex];
  LocalTensor<float> inLocal = sortDataCopyInQueue.AllocTensor<float>();
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();

  for (int64_t i = 0; i < listNum; i++) {
    LocalTensor<float> inLocalT = inLocal[GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * i];
    sorter->SetInput(srcWsGm, inLocalT);
  }

  // Second half of outLocal receives the extracted indices.
  LocalTensor<float> outLocalV = outLocal[this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST];
  sorter->SetOutput(this->sortedexpertIdxGm, this->expandDstToSrcRowGm, outLocal, outLocalV);

  LocalTensor<float> tempBuffer =
      sortedBuffer.Get<float>(GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST);
  sorter->SetBuffer(tempBuffer);
  sortDataCopyInQueue.FreeTensor(inLocal);
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Merge listNum sorted runs down to one, entirely on this core: each round
// merges groups of up to MAX_MRGSORT_LIST runs (the last group may be
// smaller and its last run shorter), quadrupling run length and ping-ponging
// between the two workspaces until a single run remains.
__aicore__ inline void MoeV2SortMultiCore::OneCoreVMSProcess(int64_t listNum, int64_t perListElements,
                                                             int64_t lastListElements) {
  int64_t coreOffset = GetSortLen<float>(this->vbsTilingData->perCoreElements);
  mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;

  for (int64_t i = 0; listNum >= 1; i++) {
    int64_t loops = (listNum + MAX_MRGSORT_LIST - 1) / MAX_MRGSORT_LIST;
    int64_t remainListNum = listNum - (loops - 1) * MAX_MRGSORT_LIST;

    // Full groups: every list in the group has perListElements elements.
    mrgsortParam.perListElements = perListElements;
    mrgsortParam.lastListElements = perListElements;

    int64_t loopOffset = GetSortLen<float>(mrgsortParam.perListElements * MAX_MRGSORT_LIST);
    for (int64_t loop = 0; loop < loops - 1; loop++) {
      InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, loop * loopOffset);
      mrgsorter.Init(&mrgsortParam);
      mrgsorter.Process();
    }

    // Tail group: fewer lists, and its last list may be shorter.
    mrgsortParam.perListElements = perListElements;
    mrgsortParam.lastListElements = lastListElements;
    InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, (loops - 1) * loopOffset);
    mrgsorter.Init(&mrgsortParam);
    mrgsorter.Process();

    // Next round operates on the merged (4x longer) runs in the other workspace.
    listNum = loops;
    lastListElements = perListElements * (remainListNum - 1) + lastListElements;
    perListElements = perListElements * MAX_MRGSORT_LIST;
    srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM;

    if (loops == 1) {
      break;
    }
  }
}
|
||||
|
||||
// One sort tile: load the slice, sort it in UB, and spill the sorted run to
// the workspace.
__aicore__ inline void MoeV2SortMultiCore::UBSortProcess(int64_t progress, int64_t size, int64_t sortNum) {
  VBSCopyIn(progress, size, sortNum);
  UBSortCompute(progress, size, sortNum);
  VBSCopyOut(progress, size, sortNum);
}
|
||||
|
||||
// VBS stage: each active core sorts its tiles (last tile may be partial),
// merges them into one run if it had several, then all cores barrier so the
// workspace is globally consistent before VMS.
__aicore__ inline void MoeV2SortMultiCore::VBSProcess() {
  if (this->blockIdx < this->vbsTilingData->needCoreNum) {
    int64_t sortNum = Ceil(sortCoreLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    for (int64_t loop = 0; loop < sortCoreLoops - 1; loop++) {
      UBSortProcess(loop, sortCoreLoopElements, sortNum);
    }

    sortNum = Ceil(sortCoreLastLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    UBSortProcess(sortCoreLoops - 1, sortCoreLastLoopElements, sortNum);
    if (sortCoreLoops > 1) {
      OneCoreVMSProcess(sortCoreLoops, sortCoreLoopElements, sortCoreLastLoopElements);
    }
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
|
||||
|
||||
// Vector-Merge-Sort stage: repeatedly merge the per-core sorted lists in groups of
// MAX_MRGSORT_LIST until at most MAX_MRGSORT_LIST lists remain (final merge is done
// by SortOutProcess). Each iteration halves-by-MAX_MRGSORT_LIST the list count and
// ping-pongs between the two workspace buffers via srcWsIndex.
__aicore__ inline void MoeV2SortMultiCore::VMSProcess() {
    int64_t currentStageNeedCoreNum = this->vmsTilingData->needCoreNum;
    // Seed the merge state from the VBS stage output: one sorted list per VBS core.
    perListElements = this->vbsTilingData->perCoreElements;
    lastListElements = this->vbsTilingData->lastCoreElements;
    listNum = this->vbsTilingData->needCoreNum;

    for (; listNum > MAX_MRGSORT_LIST;) {
        currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST);
        // Byte stride between the groups a single core merges (key+value layout).
        int64_t coreOffset = GetSortLen<float>(perListElements * MAX_MRGSORT_LIST);
        // The last group may contain fewer than MAX_MRGSORT_LIST lists.
        int64_t remainListNum = listNum - (currentStageNeedCoreNum - 1) * MAX_MRGSORT_LIST;

        if (this->blockIdx < currentStageNeedCoreNum - 1) {
            // Non-last merging core: merges a full group of equally sized lists.
            mrgsortParam.perListElements = perListElements;
            mrgsortParam.lastListElements = perListElements;
            mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
            InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, 0);
            mrgsorter.Init(&mrgsortParam);
            mrgsorter.Process();
        } else if (this->blockIdx == currentStageNeedCoreNum - 1) {
            // Last merging core: merges the remainder group, whose final list is shorter.
            mrgsortParam.perListElements = perListElements;
            mrgsortParam.lastListElements = lastListElements;
            mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
            InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, 0);
            mrgsorter.Init(&mrgsortParam);
            mrgsorter.Process();
        }
        // Advance merge state: one merged list per group produced this round.
        listNum = currentStageNeedCoreNum;
        currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST);
        // Swap source/destination workspace buffers (double-buffered merge).
        srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM;

        // New list sizes after this merge round; the last list absorbs the remainder.
        lastListElements = perListElements * (remainListNum - 1) + lastListElements;
        perListElements = perListElements * MAX_MRGSORT_LIST;
#ifndef __CCE_KT_TEST__
        // All cores must finish this merge round before the next reads its output.
        AscendC::SyncAll();
#endif
    }
}
|
||||
|
||||
// Final merge stage: core 0 merges the remaining (<= MAX_MRGSORT_LIST) sorted lists
// into the fully sorted output; all other cores just wait at the barrier.
__aicore__ inline void MoeV2SortMultiCore::SortOutProcess() {
    if (this->blockIdx < 1) {  // only core 0 performs the final merge
        mrgsortParam.perListElements = perListElements;
        mrgsortParam.lastListElements = lastListElements;
        mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;

        MoeV2MrgsortOut sorter;
        InitMoeMrgSortOut(&sorter, listNum, GetSortLen<float>(perListElements));
        sorter.Init(&mrgsortParam, pipe);
        sorter.Process();
    }
#ifndef __CCE_KT_TEST__
    // All cores rendezvous so downstream stages can read the fully sorted result.
    AscendC::SyncAll();
#endif
}
|
||||
|
||||
// Initialize the multi-core sort: capture tiling parameters, bind global-memory
// views (inputs, outputs, double-buffered sort workspace), and allocate UB queues.
//   expertIdx:                  [n * k] expert index per (token, top-k slot).
//   expertTokensCountOrCumsum:  optional per-expert token count/cumsum output.
//   expertTokensBeforeCapacity: optional per-expert pre-capacity count output.
//   workspace:                  GM scratch laid out as
//                               [sorted expertIdx | dst->src row map | 2 key/value merge buffers].
template <typename TilingData>
__aicore__ inline void MoeV2SortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum,
                                                GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace,
                                                const TilingData* tilingData, TPipe* tPipe) {
    this->totalLength = tilingData->n * tilingData->k;
    this->coreNum = tilingData->coreNum;
    this->vbsTilingData = &(tilingData->vbsComputeParamsOp);
    this->vmsTilingData = &(tilingData->vmsMiddleComputeParamsOp);
    this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp);

    // Logical core index across both sub-blocks (vector cores).
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->tileLength = this->vbsTilingData->perCorePerLoopElements;
    this->sortTotalLength = this->vbsTilingData->perCoreElements;
    if (this->blockIdx == tilingData->vbsComputeParamsOp.needCoreNum - 1) {
        // Last VBS core handles the (possibly shorter) tail slice.
        this->tileLength = this->vbsTilingData->lastCorePerLoopElements;
        this->sortTotalLength = this->vbsTilingData->lastCoreElements;
    }
    this->n = tilingData->n;
    this->k = tilingData->k;
    this->expertNum = tilingData->expertNum;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
    this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;

    // VBS param init: loop shape for this core's sort stage.
    if (this->blockIdx == this->vbsTilingData->needCoreNum - 1) {
        sortCoreLoops = this->vbsTilingData->lastCoreLoops;
        sortCoreLoopElements = this->vbsTilingData->lastCorePerLoopElements;
        sortCoreLastLoopElements = this->vbsTilingData->lastCoreLastLoopElements;
    } else {
        sortCoreLoops = this->vbsTilingData->perCoreLoops;
        sortCoreLoopElements = this->vbsTilingData->perCorePerLoopElements;
        sortCoreLastLoopElements = this->vbsTilingData->perCoreLastLoopElements;
    }

    this->pipe = tPipe;
    // This core's input slice starts at its per-core offset into expertIdx.
    expertIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)expertIdx + this->blockIdx * tilingData->vbsComputeParamsOp.perCoreElements,
        this->sortTotalLength);
    sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace),
                                      Align(this->totalLength, sizeof(int32_t)));
    expandDstToSrcRowGm.SetGlobalBuffer(
        reinterpret_cast<__gm__ int32_t*>(workspace) + Align(this->totalLength, sizeof(int32_t)),
        Align(this->totalLength, sizeof(int32_t)));

    // Spread expert-counter initialization over the cores (block-aligned chunks).
    this->perCoreExpert = Align((this->expertNum + this->coreNum - 1) / this->coreNum, sizeof(int32_t));
    this->needInitExpertCore = (this->expertNum + this->perCoreExpert - 1) / this->perCoreExpert;
    this->currentCoreExpert = this->perCoreExpert;
    if (this->blockIdx == needInitExpertCore - 1) {
        this->currentCoreExpert = this->expertNum - (this->needInitExpertCore - 1) * this->perCoreExpert;
    }
    if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
        expertTokensCountOrCumsumGm.SetGlobalBuffer(
            (__gm__ int32_t*)expertTokensCountOrCumsum + this->blockIdx * this->perCoreExpert, this->currentCoreExpert);
    }
    if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
        expertTokensBeforeCapacityGm.SetGlobalBuffer(
            (__gm__ int32_t*)expertTokensBeforeCapacity + this->blockIdx * this->perCoreExpert, this->currentCoreExpert);
    }
    // key and value: merge workspace holds (key, value) pairs, hence the factor of 2.
    int64_t kvFactor = 2;
    workspaceGms[0].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2,
                                    Align(this->totalLength, sizeof(int32_t)) * kvFactor);
    workspaceGms[1].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * (kvFactor + 2),
                                    Align(this->totalLength, sizeof(int32_t)) * kvFactor);

    // UB queue size must cover both the sort tiles and the widest merge loop.
    int64_t bufferSize = Ceil(Max(this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST, sortCoreLoopElements),
                              ONE_REPEAT_SORT_NUM) *
                         ONE_REPEAT_SORT_NUM * sizeof(int32_t) * kvFactor;
    pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize);
    pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize);
    pipe->InitBuffer(sortedBuffer, bufferSize);
}
|
||||
|
||||
// Top-level driver: zero the expert-token counters, then run the three sort stages
// in order (per-core sort -> multi-round merge -> final single-core merge-out).
// Each stage ends with its own cross-core synchronization; do not reorder.
__aicore__ inline void MoeV2SortMultiCore::Process() {
    InitExpertTokensGlobalMemory();
    VBSProcess();
    VMSProcess();
    SortOutProcess();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_VBS_ONE_CORE_H
|
||||
// ==== new file boundary (diff hunk: moe_v2_sort_one_core.h, +162 lines) ====
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_sort_one_core.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_SORT_ONE_CORE_H
|
||||
#define INNER_MOE_V2_SORT_ONE_CORE_H
|
||||
|
||||
#include "moe_v2_mrgsort.h"
|
||||
#include "moe_v2_sort_base.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Single-core sort path: used when the whole (n * k) expert-index list is small
// enough to be staged and sorted in one core's UB, avoiding the multi-core
// sort/merge pipeline entirely. Core 0 does all the work; other cores only sync.
class MoeV2SortOneCore : public MoeV2SortBase {
 public:
    __aicore__ inline MoeV2SortOneCore(){};
    // Bind GM buffers and allocate UB queues; see the definition below for layout.
    template <typename TilingData>
    __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                                GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
    // Run the CopyIn -> SortCompute -> CopyOut pipeline on core 0 only.
    __aicore__ inline void Process();

 private:
    __aicore__ inline void CopyIn();
    __aicore__ inline void SortCompute();
    __aicore__ inline void CopyOut();

 private:
    int64_t sortNum;   // tile length padded up to a multiple of ONE_REPEAT_SORT_NUM
    int64_t blockIdx;  // logical core index (block + sub-block)
};
|
||||
|
||||
// Stage the whole expert-index list into UB and synthesize the row-index payload.
// UB layout of inLocal: [0, sortNum) = expert ids, [sortNum, 2*sortNum) = row ids 0..sortNum-1.
__aicore__ inline void MoeV2SortOneCore::CopyIn() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
    DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                     0, 0, 0};
    DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
    DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams);

    // Generate 0,1,2,... so the sort can carry each element's original row index.
    LocalTensor<int32_t> rowIdxLocal = inLocal[this->sortNum];
    ArithProgression<int32_t>(rowIdxLocal, 0, 1, this->sortNum);
    sortDataCopyInQueue.EnQue(inLocal);
}
|
||||
|
||||
// Sort the staged expert ids ascending, carrying each id's source-row index.
// Trick: Sort32 sorts float keys descending, so ids are cast to float and negated
// before sorting, then negated and cast back afterwards. Padding slots beyond
// totalLength are filled with MIN_FP32 so they sink to the end of the sorted order.
__aicore__ inline void MoeV2SortOneCore::SortCompute() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertForSourceRowLocal = inLocal[0];
    LocalTensor<float> expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast<float>();
    // int32 -> float in place (reinterpreted view over the same UB).
    Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength);
    pipe_barrier(PIPE_V);
    // Negate so the descending hardware sort yields ascending expert ids.
    // NOTE(review): the element count here is tileLength while the padding below
    // is positioned from totalLength — confirm the two are consistent in this path.
    Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, this->tileLength);
    pipe_barrier(PIPE_V);

    // Fill the tail of the last 32-element repeat with MIN_FP32 via a bit mask so
    // the padding never outranks real (negated) keys.
    int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;                    // keep only the padding lanes
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }

    // Pack keys into the (key,value) region layout expected by Sort.
    LocalTensor<float> concatLocal;
    LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum));
    Concat(concatLocal, expertForSourceRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);

    // Full sort with the row indices as the carried payload.
    LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum));
    LocalTensor<uint32_t> sourceRowLocal;
    sourceRowLocal = inLocal[this->sortNum].ReinterpretCast<uint32_t>();
    Sort<float, true>(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);

    // Split (key,value) pairs back into sorted expert ids and the dst->src row map.
    LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
    LocalTensor<float> sortedExpertForSourceRowLocal = outLocal[0];
    LocalTensor<uint32_t> expandDstToSrcRowLocal;
    expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast<uint32_t>();
    Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    // Undo the negation, then convert keys back to int32 in place.
    Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength);
    pipe_barrier(PIPE_V);

    LocalTensor<int32_t> expertForSourceRowLocalInt32;
    expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast<int32_t>();
    Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength);
    sortDataCopyOutQueue.EnQue<float>(outLocal);
    sortDataCopyInQueue.FreeTensor(inLocal);
}
|
||||
|
||||
// Write both sort outputs to the GM workspace: the sorted expert ids at offset 0
// and the dst->src row map at offset sortNum within the output tile.
__aicore__ inline void MoeV2SortOneCore::CopyOut() {
    LocalTensor<int32_t> outLocal = sortDataCopyOutQueue.DeQue<int32_t>();
    DataCopyParams intriParams;
    intriParams.blockCount = 1;
    intriParams.blockLen = this->totalLength * sizeof(int32_t);  // byte length for DataCopyPad
    DataCopyPad(sortedexpertIdxGm, outLocal[0], intriParams);
    DataCopyPad(expandDstToSrcRowGm, outLocal[this->sortNum], intriParams);
    sortDataCopyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Initialize the single-core sort: compute tile/pad sizes from tiling, bind GM
// views (input ids; workspace split into [sorted ids | dst->src map]), zero the
// optional per-expert counters on the last core, and allocate the UB buffers.
template <typename TilingData>
__aicore__ inline void MoeV2SortOneCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum,
                                              GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace,
                                              const TilingData* tilingData, TPipe* tPipe) {
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    // Single-core path: the entire list is "the last core's loop" in the tiling.
    this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
    // Pad up to a full Sort32 repeat.
    this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    this->totalLength = tilingData->n * tilingData->k;
    this->coreNum = tilingData->coreNum;
    this->pipe = tPipe;
    this->n = tilingData->n;
    this->k = tilingData->k;
    this->expertNum = tilingData->expertNum;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
    this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;

    expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
    sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace), this->tileLength);
    expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace) + this->tileLength,
                                        this->tileLength);

    // The last core zero-fills the optional per-expert outputs while core 0 sorts.
    if (this->blockIdx == this->coreNum - 1) {
        if (this->expertTokensCountOrCumsumFlag > 0) {
            expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                        Align(this->expertNum, sizeof(int32_t)));
            InitGlobalMemory(expertTokensCountOrCumsumGm, this->expertNum, 0);
        }
        if (this->expertTokensBeforeCapacityFlag == 1) {
            expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity,
                                                         Align(this->expertNum, sizeof(int32_t)));
            InitGlobalMemory(expertTokensBeforeCapacityGm, this->expertNum, 0);
        }
    }
    // key and value: each sort element is a (key, value) pair, hence the factor of 2.
    int64_t kvFactor = 2;
    int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor;
    pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize);
    pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize);
    pipe->InitBuffer(tempBuffer, buffSize);
    pipe->InitBuffer(sortedBuffer, buffSize);
}
|
||||
|
||||
// Run the sort on core 0 only (CopyIn -> SortCompute -> CopyOut), then barrier so
// every core observes the sorted workspace before the next kernel stage.
__aicore__ inline void MoeV2SortOneCore::Process() {
    if (get_block_idx() + get_subblockid() * get_block_num() < 1) {
        CopyIn();
        SortCompute();
        CopyOut();
    }
    this->SyncAll();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_SORT_ONE_CORE_H
|
||||
// ==== new file boundary (diff hunk: moe_v2_src_to_dst_and_gather.h, +560 lines) ====
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_src_to_dst_and_gather.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef MOE_V2_SRC_TO_DST_AND_GATHER_H
|
||||
#define MOE_V2_SRC_TO_DST_AND_GATHER_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Fused "src-to-dst mapping + gather + dynamic int8 quantization" stage for the
// drop/pad (expert-capacity) MoE routing mode. For each sorted token slot it:
//   1. computes the destination row (expert * capacity + slot) and writes it to
//      expandedRowIdx,
//   2. gathers the token's activations, optionally multiplies by a smooth scale,
//   3. dynamically quantizes to int8 (per-row max / 127) writing expandedX and
//      dynamicQuantScale, and zero-fills unused capacity slots.
//   T: activation dtype (float or a 16-bit type cast up to float in UB).
template <typename T, typename TilingData>
class MoeV2SrcToDstAndGather {
 public:
    __aicore__ inline MoeV2SrcToDstAndGather(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                GM_ADDR dynamicQuantScale, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
    __aicore__ inline void Process();

 private:
    // Single-shot path (whole row fits in UB).
    __aicore__ inline void CopyIn(int64_t progress);
    __aicore__ inline void CopyOut(int64_t progress);
    // Multi-tile path (row processed in colLoops column tiles).
    __aicore__ inline void CopyOutLoops(int64_t progress);
    __aicore__ inline void Compute(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);
    __aicore__ inline float ComputeMax(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal,
                                       LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx,
                                       int64_t j);
    __aicore__ inline void ComputeScale(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal, float scaleTemp,
                                        int64_t dstIndex, int64_t j);
    __aicore__ inline void ComputeLoops(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);

    __aicore__ inline void CopyOutRemain();
    __aicore__ inline void SyncAll();
    __aicore__ inline void AssistInit();

 private:
    TPipe* pipe;
    // Index-stream queues (row map in, destination indices out, zero-fill template).
    TQue<QuePosition::VECIN, 1> copyInQueue;
    TQue<QuePosition::VECOUT, 1> copyOutQueue;
    TQue<QuePosition::VECOUT, 1> copyOutZeroQueue;

    // Activation-stream queues for the gather + quantize pipeline.
    TQue<QuePosition::VECIN, 1> inputXInQueue;
    TQue<QuePosition::VECIN, 1> smoothInQueue;
    TQue<QuePosition::VECOUT, 1> calcQueue;
    TQue<QuePosition::VECOUT, 1> inputXOutQueue;
    TQue<QuePosition::VECOUT, 1> scaleOutQueue;
    TQue<QuePosition::VECOUT, 1> scaleOutZeroQueue;

    GlobalTensor<int32_t> expandDstToSrcRowGm;    // sorted-order -> source-row map (from sort stage)
    GlobalTensor<int32_t> expandedRowIdxGm;       // output: source-row -> destination-row map
    GlobalTensor<int32_t> expertIdxValueGm;       // per-core (expertId, count) pairs for cross-core handoff
    GlobalTensor<int32_t> expandedExpertIdxGm;    // sorted expert id per slot
    GlobalTensor<int8_t> expandedXGm;             // output: quantized, expert-grouped activations

    GlobalTensor<T> inputXGm;                     // input activations [n, cols]
    GlobalTensor<float> quantSmoothGm;            // optional smooth scales (per expert when smoothType == 2)
    GlobalTensor<float> dynamicQuantScaleGm;      // output: per-destination-row dequant scale
    GlobalTensor<float> quantSrcGm;               // GM scratch for the two-pass (loops) quantization

    LocalTensor<int8_t> outTmpLocal;              // zero row used to pad unused capacity slots
    LocalTensor<float> scaleOutTmpLocal;          // zero scale used to pad unused capacity slots
    LocalTensor<float> smoothLocal;

    const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;

    int64_t coreNum;
    int64_t blockIdx;
    int64_t totalLength;
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t rowLoops;
    int64_t expertCapacity;   // max tokens kept per expert (drop/pad mode)
    int64_t expertNum;
    int64_t cols;
    int64_t perLoopCols;
    int64_t lastLoopCols;
    int64_t colLoops;
    int64_t perLoopColsAlign;
    int64_t k;
    int64_t colsTileLength;   // valid columns in the current column tile
    int64_t smoothType;       // 0: none, 2: per-expert smooth scale (1: presumably shared — confirm)

    // Running cursor over (expert, slot-within-capacity) while walking sorted slots.
    int64_t tokenCount = 0;
    int32_t lastExpertId = -1;          // -1 = not yet seeded from the previous core's state
    int32_t lastCoreExpertId = 0;       // expert id the previous core stopped at
    int32_t lastCoreExpertIdNum = 0;    // tokens already emitted for that expert by earlier cores
};
|
||||
|
||||
// Prepare the zero-fill templates and seed this core's (expert, slot) cursor from
// the preceding cores' published (expertId, count) pairs, so the capacity counting
// continues seamlessly across core boundaries.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::AssistInit() {
    // Zeroed int8 row used to pad unused capacity slots in expandedX.
    LocalTensor<int16_t> outLocal = copyOutZeroQueue.AllocTensor<int16_t>();
    Duplicate<int16_t>(outLocal, static_cast<int16_t>(0), this->perLoopCols);
    copyOutZeroQueue.EnQue<int16_t>(outLocal);
    // Zeroed scale block for the same padding slots.
    LocalTensor<float> scaleOutLocal = scaleOutZeroQueue.AllocTensor<float>();
    Duplicate<float>(scaleOutLocal, 0.0f, 8);
    scaleOutZeroQueue.EnQue<float>(scaleOutLocal);

    if (this->blockIdx != 0) {
        // Last expert the previous core was emitting, and how many of its tokens it wrote.
        this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2);
        this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1);
        // Earlier cores may also have ended on the same expert; accumulate their counts
        // until an earlier core's last expert differs.
        for (int64_t i = this->blockIdx - 2; i >= 0; i--) {
            int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2);
            if (lastExpertIdx < this->lastCoreExpertId) {
                break;
            }
            int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1);
            this->lastCoreExpertIdNum += lastExpertNum;
        }
    }
}
|
||||
|
||||
// Stage one row-tile of routing metadata into UB:
// [0, length) = dst->src row map, [length, 2*length) = sorted expert id per slot.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyIn(int64_t progress) {
    LocalTensor<int32_t> inLocal = copyInQueue.AllocTensor<int32_t>();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length);
    DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length);

    copyInQueue.EnQue<int32_t>(inLocal);
}
|
||||
|
||||
// Single-shot gather + dynamic int8 quantization of one token row (whole row fits
// in UB). Loads the source row, optionally applies the per-expert smooth scale,
// computes scale = max(|x|)/127, writes the scale to dynamicQuantScale[dstIdx] and
// the quantized row to expandedX[dstIdx].
//   srcIdx:    sorted-slot source index (divided by k to address the token row).
//   dstIdx:    destination row = expertIdx * expertCapacity + slot.
//   expertIdx: expert owning the destination row (selects the smooth scale).
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Compute(int32_t srcIdx, int32_t dstIdx,
                                                                      int32_t expertIdx) {
    DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};

    LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();

    if constexpr (IsSameType<T, float>::value) {
        DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
    } else {
        // Narrow dtype: land the raw T data in the upper half of the buffer so the
        // in-place Cast to float below does not overwrite its own input.
        DataCopyPad(inLocal.template ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols],
                    copyInParams, {false, 0, 0, 0});
    }

    if (smoothType == 2) {
        // Per-expert smooth scale row.
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0});
    }

    inputXInQueue.EnQue<float>(inLocal);
    // NOTE(review): smoothLocal is enqueued/dequeued even when smoothType != 2;
    // presumably it is pre-loaded elsewhere for the other smooth modes — confirm.
    smoothInQueue.EnQue(smoothLocal);
    smoothLocal = smoothInQueue.DeQue<float>();

    inLocal = inputXInQueue.DeQue<float>();

    LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();

    if constexpr (!IsSameType<T, float>::value) {
        // Widen T -> float from the staging offset into the front of the buffer.
        Cast(inLocal, inLocal.template ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
        pipe_barrier(PIPE_V);
    }

    if (smoothType != 0) {
        Mul(inLocal, inLocal, smoothLocal, this->cols);
        pipe_barrier(PIPE_V);
    }

    // Dynamic quantization: scale = max(|x|) / 127, q = x / scale.
    Abs(tempLocal, inLocal, this->cols);
    pipe_barrier(PIPE_V);

    ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols);
    pipe_barrier(PIPE_V);

    float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;

    Duplicate<float>(dynamicQuantLocal, maxValue, 8);        // scale output block (one 32B block)
    Duplicate<float>(tempLocal, maxValue, this->cols);       // broadcast divisor
    pipe_barrier(PIPE_V);

    Div(tempLocal, inLocal, tempLocal, this->cols);
    pipe_barrier(PIPE_V);

    // float -> half -> int8: no direct float->int8 vector cast.
    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols);
    pipe_barrier(PIPE_V);

    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols);

    calcQueue.FreeTensor(tempLocal);
    inputXOutQueue.EnQue(outLocal);
    scaleOutQueue.EnQue(dynamicQuantLocal);

    LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
    DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0});  // 4 bytes = one float

    outLocal = inputXOutQueue.DeQue<int8_t>();
#ifndef __CCE_KT_TEST__
    DataCopyPad(expandedXGm[dstIdx * this->cols], outLocal, copyOutParams);
#endif
    inputXInQueue.FreeTensor(inLocal);
    inputXOutQueue.FreeTensor(outLocal);
    scaleOutQueue.FreeTensor(quantScaleLocal);
}
|
||||
|
||||
// Walk the current row-tile of sorted slots (single-shot column path). For each slot:
// zero-fill any remaining capacity slots of experts we have skipped past, then — if
// the current expert still has capacity — record the destination index in
// expandedRowIdx and run the gather+quantize Compute for that token. Tokens beyond
// an expert's capacity are dropped (no else branch).
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOut(int64_t progress) {
    LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};
    DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};

    // Scalar reads below must not outrun the MTE2 copy of inLocal.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        // First tile on this core: resume from where the previous core stopped.
        this->lastExpertId = this->lastCoreExpertId;
        this->tokenCount = this->lastCoreExpertIdNum;
    }
    for (int64_t idx = 0; idx < currentLoopRows; idx++) {
        int32_t expertIdx = inLocal[length].GetValue(idx);  // expert id of this sorted slot
        SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
        int32_t index = 0;
        // Zero-fill the unused capacity slots of every expert between the previous
        // slot's expert and this one.
        while (this->lastExpertId < expertIdx) {
            while (this->tokenCount < this->expertCapacity) {
                index = this->lastExpertId * this->expertCapacity + this->tokenCount;
                DataCopyPad(expandedXGm[index * this->cols], this->outTmpLocal, copyParams1);
                DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
                this->tokenCount++;
            }
            this->tokenCount = 0;
            this->lastExpertId++;
        }

        if (this->tokenCount < this->expertCapacity) {
            int32_t outOffset = inLocal.GetValue(idx);  // source row for this slot
            index = expertIdx * this->expertCapacity + this->tokenCount;
            outLocal.SetValue(0, index);
            SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
            DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
            Compute(outOffset, index, expertIdx);
            SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
            this->tokenCount++;
        }
        // else: expert is at capacity — token is dropped.
    }
    copyInQueue.FreeTensor(inLocal);
    copyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// First pass of the tiled (row wider than UB) quantization: load column tile j of
// the source row, apply the smooth scale if enabled, stash the smoothed float tile
// in GM scratch (quantSrcGm) for the second pass, and return this tile's max(|x|).
//   dynamicQuantLocal[8] is used as a scratch slot for the tile reduction so slot 0
//   stays free for the final scale value.
template <typename T, typename TilingData>
__aicore__ inline float MoeV2SrcToDstAndGather<T, TilingData>::ComputeMax(LocalTensor<float>& inLocal,
                                                                          LocalTensor<float>& tempLocal,
                                                                          LocalTensor<float>& dynamicQuantLocal,
                                                                          int32_t srcIdx, int32_t expertIdx,
                                                                          int64_t j) {
    LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();

    DataCopyExtParams intriParamsT{1, static_cast<uint32_t>(colsTileLength * sizeof(T)), 0, 0, 0};
    DataCopyExtParams intriParamsFp32{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};

    if constexpr (!IsSameType<T, float>::value) {
        // Narrow dtype: stage raw T data in the upper half so the in-place widening
        // Cast below does not clobber its own input.
        DataCopyPad(inLocal.ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols],
                    intriParamsT, {false, 0, 0, 0});
    } else {
        DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
    }

    inputXInQueue.EnQue<float>(inLocal);
    inLocal = inputXInQueue.DeQue<float>();

    if constexpr (!IsSameType<T, float>::value) {
        Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength);
        pipe_barrier(PIPE_V);
    }

    if (smoothType != 0) {
        // Matching column tile of the expert's smooth-scale row.
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32,
                    {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();

        Mul(inLocal, inLocal, smoothLocal, colsTileLength);
        pipe_barrier(PIPE_V);
    }

    Abs(tempLocal, inLocal, colsTileLength);
    pipe_barrier(PIPE_V);

    ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength);

    // Persist the smoothed float tile so ComputeScale can re-read it after the
    // row-wide max is known.
    DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32);
    smoothInQueue.FreeTensor(smoothLocal);
    // Next tile's load must wait until this tile's GM write has drained.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);

    return dynamicQuantLocal.GetValue(8);
}
|
||||
|
||||
// Second pass of the tiled quantization: re-read the smoothed float tile j from GM
// scratch, divide by the row-wide scale, cast float -> half -> int8, and write the
// quantized tile into its destination row of expandedX.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::ComputeScale(LocalTensor<float>& inLocal,
                                                                           LocalTensor<float>& tempLocal,
                                                                           float scaleTemp, int64_t dstIndex,
                                                                           int64_t j) {
    DataCopyExtParams copyInParams{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(colsTileLength * sizeof(int8_t)), 0, 0, 0};

    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();

    DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0});
    inputXInQueue.EnQue<float>(inLocal);
    inLocal = inputXInQueue.DeQue<float>();

    // Broadcast the row-wide scale as the divisor.
    Duplicate<float>(tempLocal, scaleTemp, colsTileLength);
    pipe_barrier(PIPE_V);

    Div(tempLocal, inLocal, tempLocal, colsTileLength);
    pipe_barrier(PIPE_V);

    // float -> half -> int8: no direct float->int8 vector cast.
    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength);
    pipe_barrier(PIPE_V);

    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, colsTileLength);

    inputXOutQueue.EnQue(outLocal);
    outLocal = inputXOutQueue.DeQue<int8_t>();
    DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams);

    inputXOutQueue.FreeTensor(outLocal);
    // Next tile's scratch load must wait for this tile's output write to drain.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
}
|
||||
|
||||
// Two-pass dynamic quantization of one token row that is too wide for UB:
// pass 1 (ComputeMax per column tile) finds the row-wide max(|x|) while parking the
// smoothed floats in GM scratch; pass 2 (ComputeScale per tile) divides by the
// resulting scale and emits int8. The scale itself is written once per row.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::ComputeLoops(int32_t srcIdx, int32_t dstIdx,
                                                                           int32_t expertIdx) {
    LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
    LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
    LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();

    // Seed the running max with -FLT_MAX (bit pattern 0xFF7FFFFF).
    uint32_t tmp = 0xFF7FFFFF;
    float reduceMax = *((float*)&tmp);
    for (int64_t j = 0; j < this->colLoops; j++) {
        colsTileLength = this->perLoopCols;
        if (j == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;  // final tile may be shorter
        }
        float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j);
        reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
    }

    // Row-wide dynamic quantization scale.
    float scaleTemp = reduceMax / 127.0f;
    Duplicate<float>(quantScaleLocal, scaleTemp, 8);
    scaleOutQueue.EnQue(quantScaleLocal);
    quantScaleLocal = scaleOutQueue.DeQue<float>();

    DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0});  // one float

    for (int64_t j = 0; j < this->colLoops; j++) {
        colsTileLength = this->perLoopCols;
        if (j == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;
        }
        ComputeScale(inLocal, tempLocal, scaleTemp, dstIdx, j);
    }

    inputXInQueue.FreeTensor(inLocal);
    calcQueue.FreeTensor(tempLocal);
    scaleOutQueue.FreeTensor(quantScaleLocal);
}
|
||||
|
||||
// Multi-tile (wide-row) output path. For each (srcRow, expertIdx) pair of
// the current tile:
//   1. Zero-fill (data + quant scale) any remaining capacity slots of the
//      experts between the previously handled expert and expertIdx.
//   2. If the target expert still has capacity, scatter the destination row
//      index to expandedRowIdxGm[srcRow] and quantize the token into its
//      slot via ComputeLoops; tokens beyond capacity are dropped.
// lastExpertId / tokenCount persist across tiles and into CopyOutRemain.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOutLoops(int64_t progress) {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  int64_t length = Align(currentLoopRows, sizeof(int32_t));
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};

  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);  // indices must land before scalar reads
  // First tile on this core: resume the expert/slot count where the
  // predecessor cores left off (recovered in AssistInit).
  if (this->lastExpertId == -1) {
    this->lastExpertId = this->lastCoreExpertId;
    this->tokenCount = this->lastCoreExpertIdNum;
  }
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    int32_t expertIdx = inLocal[length].GetValue(idx);  // expert ids sit in the 2nd half of inLocal
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    int32_t index = 0;
    // Zero-pad the unfilled slots of all experts we skipped over.
    while (this->lastExpertId < expertIdx) {
      while (this->tokenCount < this->expertCapacity) {
        index = this->lastExpertId * this->expertCapacity + this->tokenCount;
        int64_t col = this->perLoopCols;
        DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
        for (int64_t i = 0; i < this->colLoops; i++) {
          if (i == this->colLoops - 1) {
            col = this->lastLoopCols;  // final tile may be shorter
          }
          DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(int8_t)), 0, 0, 0};
          DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1);
          SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
        }
        this->tokenCount++;
      }
      this->tokenCount = 0;
      this->lastExpertId++;
    }

    // Token fits within the expert's capacity: record its destination row
    // and quantize it into place.
    if (this->tokenCount < this->expertCapacity) {
      int32_t outOffset = inLocal.GetValue(idx);  // original (source) row index
      index = expertIdx * this->expertCapacity + this->tokenCount;
      outLocal.SetValue(0, index);
      SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
      DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
      if (smoothType == 2) {
        // Per-expert smoothing: pass the real expert id to select its scale row.
        ComputeLoops(outOffset, index, expertIdx);
      } else {
        ComputeLoops(outOffset, index, 0);
      }
      SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      this->tokenCount++;
    }
  }
  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Tail fill: only the LAST active core zero-fills every remaining expert
// capacity slot (data and quant scale) past the final routed token, so the
// expert-major output buffers are fully initialized. All other cores only
// release their staged zero tiles.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOutRemain() {
  if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
    scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
    return;
  }
  while (this->lastExpertId < this->expertNum) {
    while (this->tokenCount < this->expertCapacity) {
      int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
      int64_t col = this->perLoopCols;
      // Zero quant scale for the padded slot ({1, 4, ...}: one 4-byte float).
      DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
      for (int64_t i = 0; i < this->colLoops; i++) {
        if (i == this->colLoops - 1) {
          col = this->lastLoopCols;  // final tile may be shorter
        }
        DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(int8_t)), 0, 0, 0};
        DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
        SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      }
      this->tokenCount++;
    }
    this->tokenCount = 0;
    this->lastExpertId++;
  }
  copyOutZeroQueue.FreeTensor(this->outTmpLocal);
  scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
}
|
||||
|
||||
template <typename T, typename TilingData>
|
||||
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx,
|
||||
GM_ADDR expandedX, GM_ADDR dynamicQuantScale,
|
||||
GM_ADDR workspace, const TilingData* tilingData,
|
||||
TPipe* tPipe) {
|
||||
int64_t blockNum = GetBlockNum();
|
||||
this->pipe = tPipe;
|
||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
|
||||
this->coreNum = tilingData->coreNum;
|
||||
this->totalLength = tilingData->n * tilingData->k;
|
||||
this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp);
|
||||
this->expertNum = tilingData->expertNum;
|
||||
this->expertCapacity = tilingData->expertCapacity;
|
||||
this->cols = tilingData->cols;
|
||||
this->k = tilingData->k;
|
||||
this->smoothType = tilingData->smoothType;
|
||||
|
||||
if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
|
||||
this->coreRows = this->srcToDstTilingData->lastCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
|
||||
this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
|
||||
} else {
|
||||
this->coreRows = this->srcToDstTilingData->perCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
|
||||
this->rowLoops = this->srcToDstTilingData->perCoreLoops;
|
||||
}
|
||||
this->perLoopCols = this->srcToDstTilingData->perLoopCols;
|
||||
this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
|
||||
this->colLoops = this->srcToDstTilingData->colLoops;
|
||||
this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T));
|
||||
|
||||
inputXGm.SetGlobalBuffer((__gm__ T*)x);
|
||||
quantSmoothGm.SetGlobalBuffer((__gm__ float*)scale);
|
||||
dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
|
||||
|
||||
int64_t length = Align(this->totalLength, sizeof(int32_t));
|
||||
expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length);
|
||||
expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX, this->expertNum * this->expertCapacity * this->cols);
|
||||
|
||||
expandedExpertIdxGm.SetGlobalBuffer(
|
||||
(__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
expandDstToSrcRowGm.SetGlobalBuffer(
|
||||
(__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2);
|
||||
if (this->colLoops > 1) {
|
||||
quantSrcGm.SetGlobalBuffer((__gm__ float*)workspace + length * 2 + this->coreNum * 2 + this->blockIdx * this->cols,
|
||||
this->cols * sizeof(float));
|
||||
}
|
||||
|
||||
pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
|
||||
pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
|
||||
pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));
|
||||
|
||||
int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T));
|
||||
perLoopColsAlignBytes =
|
||||
Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES));
|
||||
|
||||
pipe->InitBuffer(inputXInQueue, 1, perLoopColsAlignBytes);
|
||||
pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
|
||||
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
|
||||
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
|
||||
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
|
||||
pipe->InitBuffer(scaleOutZeroQueue, 1, BLOCK_BYTES);
|
||||
}
|
||||
|
||||
// Drives the gather + quantize stage on active cores. When a row is wider
// than one UB tile (colLoops > 1) the multi-pass CopyOutLoops path runs;
// otherwise the single-pass CopyOut path runs, optionally pre-loading a
// shared smooth-scale row (smoothType == 1). Finishes by zero-padding the
// tail of the capacity buffer via CopyOutRemain.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Process() {
  if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
    AssistInit();
    // Zero tiles staged by AssistInit, reused for capacity padding.
    this->outTmpLocal = copyOutZeroQueue.DeQue<int8_t>();
    this->scaleOutTmpLocal = scaleOutZeroQueue.DeQue<float>();
    currentLoopRows = perLoopRows;
    if (colLoops > 1) {
      for (int64_t loop = 0; loop < this->rowLoops; loop++) {
        if (loop == this->rowLoops - 1) {
          currentLoopRows = lastLoopRows;  // final row tile may be shorter
        }
        CopyIn(loop);
        CopyOutLoops(loop);
      }
    } else {
      smoothLocal = smoothInQueue.AllocTensor<float>();
      if (smoothType == 1) {
        // Shared smoothing: a single row of `cols` floats applied to all tokens.
        DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
        DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0});
      }
      for (int64_t loop = 0; loop < this->rowLoops; loop++) {
        if (loop == this->rowLoops - 1) {
          currentLoopRows = lastLoopRows;
        }
        CopyIn(loop);
        CopyOut(loop);
      }
      smoothInQueue.FreeTensor(smoothLocal);
    }
    CopyOutRemain();
  }
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // MOE_V2_SRC_TO_DST_AND_GATHER_H
|
||||
@@ -0,0 +1,164 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_src_to_dst_op.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_SRC_TO_DST_H
|
||||
#define INNER_MOE_V2_SRC_TO_DST_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Scatter stage: converts the dst->src row mapping held in workspace into
// the src->dst row table (expandSrcToDstRow) consumed by the gather stage.
// Each core processes a contiguous slice of rows described by the tiling.
class MoeV2SrcToDstOp {
 public:
  __aicore__ inline MoeV2SrcToDstOp(){};
  // Binds GM buffers, unpacks tiling, and allocates the UB queues.
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  // Row loop (CopyIn -> Compute -> CopyOut) followed by a cross-core barrier.
  __aicore__ inline void Process();

 private:
  __aicore__ inline void CopyIn(int64_t progress);   // load a tile of dst->src indices
  __aicore__ inline void Compute(int64_t progress);  // build destination indices from the assist table
  __aicore__ inline void CopyOut();                  // scatter dst indices to their source positions
  __aicore__ inline void SyncAll();                  // multi-core barrier (no-op when coreNum == 1)
  __aicore__ inline void AssistInit();               // load and bias the assist index table

 private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;
  TQue<QuePosition::VECOUT, 1> copyOutQueue;
  TBuf<TPosition::VECCALC> assistBuffer;

  GlobalTensor<int32_t> expandDstToSrcRowGm;  // input mapping (in workspace)
  GlobalTensor<int32_t> expandSrcToDstRowGm;  // output mapping
  GlobalTensor<int32_t> assistGm;             // assist table (presumably ascending indices; see Compute)

  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;

  int64_t coreNum;          // total cores from tiling
  int64_t blockIdx;         // this core's linear index
  int64_t totalLength;      // n * k expanded rows overall
  int64_t currentLoopRows;  // rows handled by the current loop iteration
  int64_t coreRows;         // rows assigned to this core
  int64_t perLoopRows;      // rows per full loop
  int64_t lastLoopRows;     // rows in the final (possibly short) loop
};
|
||||
|
||||
// Loads the constant assist index table into UB and biases every entry by
// this core's starting row, so Compute can emit global destination indices.
__aicore__ inline void MoeV2SrcToDstOp::AssistInit() {
#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1
  OOMCheckAddrRange(assistGm.GetPhyAddr(), ASSIST_NUM * sizeof(int32_t));
#endif
  LocalTensor<int32_t> assistTensor = assistBuffer.Get<int32_t>(ASSIST_NUM);
  DataCopy(assistTensor, assistGm, ASSIST_NUM);
  // The GM->UB copy must complete before the vector Adds consumes the data.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  Adds(assistTensor, assistTensor, (int32_t)(this->blockIdx * this->srcToDstTilingData->perCoreRows), ASSIST_NUM);
}
|
||||
|
||||
// Stages one tile of the dst->src row mapping from GM into UB and hands it
// to the input queue for the Compute/CopyOut pair.
__aicore__ inline void MoeV2SrcToDstOp::CopyIn(int64_t progress) {
  const int64_t alignedRows = Align(currentLoopRows, sizeof(int32_t));
  LocalTensor<int32_t> rowIdxLocal = copyInQueue.AllocTensor<int32_t>();
  DataCopy(rowIdxLocal, expandDstToSrcRowGm[progress * perLoopRows], alignedRows);
  copyInQueue.EnQue<int32_t>(rowIdxLocal);
}
|
||||
|
||||
// Generates destination row indices for the current tile: each chunk of the
// output is the biased assist table plus (perLoopRows * progress +
// i * ASSIST_INDEX_NUM), producing consecutive global destination positions.
__aicore__ inline void MoeV2SrcToDstOp::Compute(int64_t progress) {
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  LocalTensor<int32_t> assistTensor = assistBuffer.Get<int32_t>(ASSIST_NUM);

  // Order against the vector Adds performed in AssistInit.
  pipe_barrier(PIPE_V);
  int64_t loops = Ceil(currentLoopRows, ASSIST_INDEX_NUM);
  for (int64_t i = 0; i < loops; i++) {
    Adds(outLocal[i * ASSIST_NUM], assistTensor,
         static_cast<int32_t>(this->perLoopRows * progress + i * ASSIST_INDEX_NUM), ASSIST_NUM);
  }
  pipe_barrier(PIPE_V);
  copyOutQueue.EnQue<int32_t>(outLocal);
}
|
||||
|
||||
// Scatters one destination index per row: for row idx, reads its source
// position from inLocal and writes the precomputed destination index
// (stored one per 32-byte block in outLocal -- INT32_ONE_BLOCK_NUM stride)
// to expandSrcToDstRowGm[source].
__aicore__ inline void MoeV2SrcToDstOp::CopyOut() {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.DeQue<int32_t>();
  // Ensure the GM->UB copy has landed before scalar GetValue reads.
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  intriParams.blockLen = sizeof(int32_t);  // 4-byte payload per scatter
  uint32_t outOffset;
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    outOffset = inLocal.GetValue(idx);
    DataCopyPad(expandSrcToDstRowGm[outOffset], outLocal[idx * INT32_ONE_BLOCK_NUM], intriParams);
  }

  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Cross-core barrier; skipped for single-core launches and under CPU
// simulation (__CCE_KT_TEST__), where hardware synchronization is unavailable.
__aicore__ inline void MoeV2SrcToDstOp::SyncAll() {
  if (coreNum == 1) {
    return;
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
|
||||
|
||||
template <typename TilingData>
|
||||
__aicore__ inline void MoeV2SrcToDstOp::Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData,
|
||||
TPipe* tPipe) {
|
||||
int64_t blockNum = GetBlockNum();
|
||||
this->pipe = tPipe;
|
||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
|
||||
this->coreNum = tilingData->coreNum;
|
||||
this->totalLength = tilingData->n * tilingData->k;
|
||||
this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp);
|
||||
|
||||
if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
|
||||
this->coreRows = this->srcToDstTilingData->lastCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
|
||||
} else {
|
||||
this->coreRows = this->srcToDstTilingData->perCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
|
||||
}
|
||||
|
||||
expandSrcToDstRowGm.SetGlobalBuffer((__gm__ int32_t*)expandSrcToDstRow, Align(this->totalLength, sizeof(int32_t)));
|
||||
expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) +
|
||||
this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
assistGm.SetGlobalBuffer((__gm__ int32_t*)assist, ASSIST_NUM);
|
||||
|
||||
pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES);
|
||||
pipe->InitBuffer(copyOutQueue, 1, Ceil(this->perLoopRows, ASSIST_NUM) * ASSIST_NUM * BLOCK_BYTES);
|
||||
pipe->InitBuffer(assistBuffer, ASSIST_NUM * sizeof(int32_t));
|
||||
}
|
||||
|
||||
// Main loop of the scatter stage. Cores beyond needCoreNum do no work but
// still join the final barrier. The last iteration is peeled so the
// (possibly shorter) lastLoopRows tile is handled without a branch in the
// hot loop.
// NOTE(review): assumes coreRows > 0 for every active core (loops >= 1);
// otherwise `loops - 1` would index -1 -- confirm against the tiling.
__aicore__ inline void MoeV2SrcToDstOp::Process() {
  if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
    int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows;  // ceil division
    currentLoopRows = perLoopRows;
    AssistInit();
    for (int64_t loop = 0; loop < loops - 1; loop++) {
      CopyIn(loop);
      Compute(loop);
      CopyOut();
    }
    currentLoopRows = lastLoopRows;
    CopyIn(loops - 1);
    Compute(loops - 1);
    CopyOut();
  }
  this->SyncAll();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_SRC_TO_DST_H
|
||||
@@ -0,0 +1,269 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_v2_src_to_dst_with_capacity.h
|
||||
* \brief
|
||||
*/
|
||||
#ifndef INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
|
||||
#define INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
|
||||
|
||||
#include "moe_v2_common.h"
|
||||
|
||||
namespace MoeInitRoutingQuantV2 {
|
||||
using namespace AscendC;
|
||||
using namespace optiling;
|
||||
// Capacity-mode scatter: assigns each routed token a slot inside its
// expert's fixed-capacity region of expandedX, zero-filling unused slots
// and dropping tokens that exceed the expert's capacity.
//   T - element type of expandedX (int8_t gets special zero-fill handling).
template <typename T, typename TilingData>
class MoeV2SrcToDstWithCapacity {
 public:
  __aicore__ inline MoeV2SrcToDstWithCapacity(){};
  // Binds GM buffers, unpacks tiling, and allocates the UB queues.
  __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace,
                              const TilingData* tilingData, TPipe* tPipe);
  // Row loop (CopyIn -> CopyOut), tail zero-fill, then cross-core barrier.
  __aicore__ inline void Process();

 private:
  __aicore__ inline void CopyIn(int64_t progress);   // load dst->src mapping + expert ids
  __aicore__ inline void CopyOut(int64_t progress);  // place tokens into capacity slots
  __aicore__ inline void CopyOutRemain();            // zero-fill remaining experts (last core only)
  __aicore__ inline void SyncAll();                  // multi-core barrier (no-op when coreNum == 1)
  __aicore__ inline void AssistInit();               // stage zero tile + recover cross-core counters

 private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;
  TQue<QuePosition::VECOUT, 1> copyOutQueue;
  TQue<QuePosition::VECOUT, 1> copyOutZeroQueue;  // holds the zero tile used for padding

  GlobalTensor<int32_t> expandDstToSrcRowGm;   // sorted dst->src mapping (workspace)
  GlobalTensor<int32_t> expandedRowIdxGm;      // output: per-source-row destination index
  GlobalTensor<int32_t> expertIdxValueGm;      // per-core (lastExpertId, count) pairs (workspace)
  GlobalTensor<int32_t> expandedExpertIdxGm;   // sorted expert ids (workspace)
  GlobalTensor<T> expandedXGm;                 // expert-major, capacity-padded output

  LocalTensor<T> outTmpLocal;  // staged zero tile for padding

  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;

  int64_t coreNum;          // total cores from tiling
  int64_t blockIdx;         // this core's linear index
  int64_t totalLength;      // n * k expanded rows overall
  int64_t currentLoopRows;  // rows handled by the current loop iteration
  int64_t coreRows;         // rows assigned to this core
  int64_t perLoopRows;      // rows per full loop
  int64_t lastLoopRows;     // rows in the final (possibly short) loop
  int64_t rowLoops;         // number of row loops for this core
  int64_t expertCapacity;   // max tokens per expert
  int64_t expertNum;        // number of experts
  int64_t cols;             // hidden size per token
  int64_t perLoopCols;      // columns per full column tile
  int64_t lastLoopCols;     // columns in the final column tile
  int64_t colLoops;         // number of column tiles per row

  // Running fill state, carried across tiles and into CopyOutRemain.
  int64_t tokenCount = 0;          // tokens placed for the current expert
  int32_t lastExpertId = -1;       // -1 until the first tile initializes it
  int32_t lastCoreExpertId = 0;    // expert the previous core ended on
  int32_t lastCoreExpertIdNum = 0; // tokens previous cores placed for that expert
};
|
||||
|
||||
// Stages the zero tile used to pad unfilled capacity slots and, on non-first
// cores, reconstructs how many tokens the preceding cores already assigned
// to this core's first expert so slot counting continues across cores.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::AssistInit() {
  if constexpr (IsSameType<T, int8_t>::value) {
    // int8 tiles are zeroed through int16 lanes (presumably because the
    // vector Duplicate lacks an int8 form -- confirm against the AscendC API).
    LocalTensor<int16_t> outLocal = copyOutZeroQueue.AllocTensor<int16_t>();
    Duplicate<int16_t>(outLocal, static_cast<int16_t>(0), this->perLoopCols);
    copyOutZeroQueue.EnQue<int16_t>(outLocal);
  } else {
    LocalTensor<T> outLocal = copyOutZeroQueue.AllocTensor<T>();
    Duplicate<T>(outLocal, static_cast<T>(0), this->perLoopCols);
    copyOutZeroQueue.EnQue<T>(outLocal);
  }

  if (this->blockIdx != 0) {
    // expertIdxValueGm stores one (lastExpertId, count) pair per core. Walk
    // backwards, accumulating counts from earlier cores that ended on the
    // same expert as the immediately preceding core.
    this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2);
    this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1);
    for (int64_t i = this->blockIdx - 2; i >= 0; i--) {
      int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2);
      if (lastExpertIdx < this->lastCoreExpertId) {
        break;
      }
      int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1);
      this->lastCoreExpertIdNum += lastExpertNum;
    }
  }
}
|
||||
|
||||
template <typename T, typename TilingData>
|
||||
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyIn(int64_t progress) {
|
||||
LocalTensor<int32_t> inLocal = copyInQueue.AllocTensor<int32_t>();
|
||||
int64_t length = Align(currentLoopRows, sizeof(int32_t));
|
||||
DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length);
|
||||
DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length);
|
||||
copyInQueue.EnQue<int32_t>(inLocal);
|
||||
}
|
||||
|
||||
// Places each routed token into its expert's capacity-limited slot region.
// For every incoming (srcRow, expertIdx) pair:
//   1. Experts between the previously handled expert and expertIdx get
//      their remaining slots zero-filled from the staged zero tile.
//   2. If the target expert still has capacity, the destination row index
//      is scattered to expandedRowIdxGm[srcRow]; tokens beyond capacity are
//      dropped (no index written).
// lastExpertId / tokenCount persist across tiles and into CopyOutRemain.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyOut(int64_t progress) {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  int64_t length = Align(currentLoopRows, sizeof(int32_t));
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};

  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);  // indices must land before scalar reads
  // First tile on this core: resume counting where the predecessors left off.
  if (this->lastExpertId == -1) {
    this->lastExpertId = this->lastCoreExpertId;
    this->tokenCount = this->lastCoreExpertIdNum;
  }
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    int32_t expertIdx = inLocal[length].GetValue(idx);  // expert ids sit in the 2nd half of inLocal
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    int32_t index = 0;
    // Zero-pad the unfilled capacity slots of all skipped experts.
    while (this->lastExpertId < expertIdx) {
      while (this->tokenCount < this->expertCapacity) {
        index = this->lastExpertId * this->expertCapacity + this->tokenCount;
        int64_t col = this->perLoopCols;
        for (int64_t i = 0; i < this->colLoops; i++) {
          if (i == this->colLoops - 1) {
            col = this->lastLoopCols;  // final column tile may be shorter
          }
#ifdef __CCE_KT_TEST__
          // CPU-simulation debugging cannot use multi-core synchronization,
          // so `index` may contain uninitialized dirty data; skip writes that
          // would fall outside the output buffer in that mode.
          if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) {
            continue;
          }
#endif
          DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(T)), 0, 0, 0};
          DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1);
          SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
        }
        this->tokenCount++;
      }
      this->tokenCount = 0;
      this->lastExpertId++;
    }

    // Token fits inside the expert's capacity: record its destination row.
    if (this->tokenCount < this->expertCapacity) {
      int32_t outOffset = inLocal.GetValue(idx);  // original (source) row index
      index = expertIdx * this->expertCapacity + this->tokenCount;
      outLocal.SetValue(0, index);
      SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
      DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
      SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      this->tokenCount++;
    }
  }
  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
|
||||
|
||||
// Tail fill: only the LAST active core zero-fills every capacity slot of
// the remaining experts (past the last routed token), so the expert-major
// output is fully initialized. Other cores just release the staged zero tile.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyOutRemain() {
  if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
    return;
  }
  while (this->lastExpertId < this->expertNum) {
    while (this->tokenCount < this->expertCapacity) {
      int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
      int64_t col = this->perLoopCols;
      for (int64_t i = 0; i < this->colLoops; i++) {
        if (i == this->colLoops - 1) {
          col = this->lastLoopCols;  // final column tile may be shorter
        }
        DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(T)), 0, 0, 0};
        DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
        SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      }
      this->tokenCount++;
    }
    this->tokenCount = 0;
    this->lastExpertId++;
  }
  copyOutZeroQueue.FreeTensor(this->outTmpLocal);
}
|
||||
|
||||
// Cross-core barrier; skipped for single-core launches and under CPU
// simulation (__CCE_KT_TEST__), where hardware synchronization is unavailable.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::SyncAll() {
  if (coreNum == 1) {
    return;
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
|
||||
|
||||
template <typename T, typename TilingData>
|
||||
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX,
|
||||
GM_ADDR workspace, const TilingData* tilingData,
|
||||
TPipe* tPipe) {
|
||||
int64_t blockNum = GetBlockNum();
|
||||
this->pipe = tPipe;
|
||||
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
|
||||
this->coreNum = tilingData->coreNum;
|
||||
this->totalLength = tilingData->n * tilingData->k;
|
||||
this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp);
|
||||
this->expertNum = tilingData->expertNum;
|
||||
this->expertCapacity = tilingData->expertCapacity;
|
||||
this->cols = tilingData->cols;
|
||||
|
||||
if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
|
||||
this->coreRows = this->srcToDstTilingData->lastCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
|
||||
this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
|
||||
} else {
|
||||
this->coreRows = this->srcToDstTilingData->perCoreRows;
|
||||
this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
|
||||
this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
|
||||
this->rowLoops = this->srcToDstTilingData->perCoreLoops;
|
||||
}
|
||||
this->perLoopCols = this->srcToDstTilingData->perLoopCols;
|
||||
this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
|
||||
this->colLoops = this->srcToDstTilingData->colLoops;
|
||||
|
||||
int64_t length = Align(this->totalLength, sizeof(int32_t));
|
||||
expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length);
|
||||
expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, this->expertNum * this->expertCapacity * this->cols);
|
||||
|
||||
expandedExpertIdxGm.SetGlobalBuffer(
|
||||
(__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
expandDstToSrcRowGm.SetGlobalBuffer(
|
||||
(__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows,
|
||||
Align(this->coreRows, sizeof(int32_t)));
|
||||
expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2);
|
||||
|
||||
pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
|
||||
pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
|
||||
if constexpr (IsSameType<T, int8_t>::value) {
|
||||
pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));
|
||||
} else {
|
||||
pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(T)));
|
||||
}
|
||||
}
|
||||
|
||||
// Row loop on active cores: CopyIn -> CopyOut per tile, then zero-fill the
// remaining expert slots. Every core -- active or not -- joins the final
// cross-core barrier.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::Process() {
  if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
    AssistInit();
    // Zero tile staged by AssistInit, reused for capacity padding.
    this->outTmpLocal = copyOutZeroQueue.DeQue<T>();
    currentLoopRows = perLoopRows;
    for (int64_t loop = 0; loop < this->rowLoops; loop++) {
      if (loop == this->rowLoops - 1) {
        currentLoopRows = lastLoopRows;  // final row tile may be shorter
      }
      CopyIn(loop);
      CopyOut(loop);
    }
    CopyOutRemain();
  }
  this->SyncAll();
}
|
||||
} // namespace MoeInitRoutingQuantV2
|
||||
#endif // INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
|
||||
@@ -0,0 +1,66 @@
|
||||
#pragma once
|
||||
namespace optiling {
|
||||
// Hardware resource description for one AI core, filled from platform info.
struct AiCoreParams {
    uint64_t ubSize;    // unified buffer size in bytes
    uint64_t blockDim;  // number of blocks (cores) available
    uint64_t aicNum;    // number of AIC (cube) cores -- confirm against platform API

    uint64_t l1Size;    // L1 buffer size in bytes
    uint64_t l0aSize;   // L0A buffer size in bytes
    uint64_t l0bSize;   // L0B buffer size in bytes
    uint64_t l0cSize;   // L0C buffer size in bytes
};

// Base class for MoeInitRoutingQuantV2 tiling strategies.
//
// DoTiling drives the fixed pipeline
//   GetShapeAttrsInfo -> GetPlatformInfo -> DoOpTiling -> GetWorkspaceSize
//   -> PostTiling -> GetTilingKey
// stopping at the first stage that reports failure. Subclasses implement the
// virtual stages; results are published through blockDim_, workspaceSize_
// and tilingKey_.
//
// Fixes vs. original: added a virtual destructor (abstract base deleted
// polymorphically would otherwise be UB) and collapsed the repeated
// `if (!ret) return ret;` ladder into fail-fast guards.
// NOTE(review): the "//protected:" markers below are commented out, so all
// members are effectively public; confirm whether that is intentional.
class TilingBaseClass {
public:
    virtual ~TilingBaseClass() = default;

    // Runs the full tiling pipeline. Returns true only if every stage
    // succeeds; on success tilingKey_ holds the value from GetTilingKey().
    // ("inuptXDtypeSize" is a pre-existing typo of "inputXDtypeSize"; the
    // name is kept so overriders/callers stay textually consistent.)
    bool DoTiling(
        int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0,
        int64_t aivCoreNum, int64_t ubSizePlatForm)
    {
        if (!GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode,
                               expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag,
                               inuptXDtypeSize, quantMode, scaleDim0)) {
            return false;
        }
        if (!GetPlatformInfo(aivCoreNum, ubSizePlatForm)) {
            return false;
        }
        if (!DoOpTiling()) {
            return false;
        }
        if (!GetWorkspaceSize()) {
            return false;
        }
        if (!PostTiling()) {
            return false;
        }
        tilingKey_ = GetTilingKey();

        return true;
    }

    //protected:
    // Stage hooks implemented by concrete tiling classes.
    virtual bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) = 0;
    virtual bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) = 0;

    virtual bool DoOpTiling() = 0;
    virtual bool GetWorkspaceSize() = 0;
    virtual bool PostTiling() = 0;
    virtual uint64_t GetTilingKey() const = 0;
    //protected:
    uint32_t blockDim_{0};       // launch block dimension chosen by tiling
    uint64_t workspaceSize_{0};  // required GM workspace in bytes
    uint64_t tilingKey_{0};      // kernel-dispatch key, set only on success
    AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};
};
|
||||
|
||||
}
|
||||
@@ -0,0 +1,376 @@
|
||||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_token_unpermute.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef MOE_TOKEN_UNPERMUTE
|
||||
#define MOE_TOKEN_UNPERMUTE
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "moe_token_unpermute_tiling.h"
|
||||
using namespace AscendC;
|
||||
|
||||
|
||||
// Kernel that reverses a MoE token permutation: rebuilds unpermuted_tokens
// from permuted_tokens using sorted_indices, optionally weighting each
// contribution by probs when PROBS is set. (High-level behavior inferred
// from member names; confirm against the method definitions below.)
//   T1 - token dtype, T2 - index dtype, T3 - probability dtype.
template <typename T1, typename T2, typename T3, bool PROBS> class KernelMoeTokenUnpermute {
public:
    __aicore__ inline KernelMoeTokenUnpermute()
    {
    }

    // Binds GM buffers and unpacks the tiling data.
    __aicore__ inline void Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs,
                                GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data);
    __aicore__ inline void Process();

protected:
    __aicore__ inline void CalMultiOutToken(const int64_t out_offset, const int64_t out_tokens_number);
    __aicore__ inline void CalSingleOutToken(const int64_t start_token, const int64_t out_token_idx);
    // Handles one hidden-dimension slice of an output token.
    __aicore__ inline void CalPartOutToken(const int64_t start_token, const int64_t h_index, const int64_t h_length,
                                           const int64_t out_token_index);
    __aicore__ inline void CopyTokenIn(const T2 in_token_index, const int64_t h_index, const int64_t h_length);
    __aicore__ inline void CalFirstToken(const float prob_value, const int64_t h_length);
    __aicore__ inline void CalToken(const float prob_value, const int64_t h_length);
    __aicore__ inline void CopyOut(const int64_t out_token_index, const int64_t h_index, const int64_t h_length);

    TPipe pipe;
    TQue<QuePosition::VECIN, 1> tokens_inque, indices_inque, probs_inque;
    TBuf<TPosition::VECCALC> temp_buffer0, temp_buffer1, temp_buffer2;
    TQue<QuePosition::VECOUT, 1> outque;
    GlobalTensor<T1> tokensGM, outGM;    // permuted input / unpermuted output
    GlobalTensor<T2> indicesGM;          // sorted permutation indices
    GlobalTensor<T3> probsGM;            // per-token probabilities (PROBS only)
    LocalTensor<T2> indicesLocal;
    LocalTensor<float> token_tensor0, token_tensor1, probs_tensor;
    DataCopyPadExtParams<T1> extParams1{false, 0, 0, 0};
    DataCopyPadExtParams<T2> extParams2{false, 0, 0, 0};
    DataCopyPadExtParams<T3> extParams3{false, 0, 0, 0};
    DataCopyExtParams copyParams{1, 0, 0, 0, 0};

    constexpr static uint32_t BLOCK_SIZE = 32;   // UB block granularity (bytes)
    constexpr static uint32_t ALIGN_512 = 512;   // larger alignment unit (bytes)

    // Tiling-derived split sizes along the hidden and token dimensions.
    int64_t hidden_size;
    int64_t top_k;
    int64_t num_out_tokens;
    int64_t hidden_splited_length;
    int64_t hidden_splited_num;
    int64_t hidden_splited_remain;
    int64_t tokens_core_length;
    int64_t tokens_core_remain;
    int64_t tokens_splited_length;
    int64_t tokens_splited_num;
    int64_t tokens_splited_remain;
    int32_t blockIdx;
    int32_t blockNum;
};
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void
KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs,
                                                 GM_ADDR unpermuted_tokens,
                                                 const MoeTokenUnpermuteTilingData *__restrict tiling_data)
{
    // Logical vector-core id/count spanning both sub-blocks of each AI core.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->blockNum = get_block_num() * get_subblockdim();

    if (blockIdx >= blockNum) {
        return;
    }
    ASSERT(blockNum != 0 && "block dim can not be zero!");
    // row_input
    this->hidden_size = tiling_data->hidden_size;
    this->top_k = tiling_data->top_k;
    this->num_out_tokens = tiling_data->num_out_tokens;
    // hidden_tiling
    this->hidden_splited_length = tiling_data->hidden_splited_length;
    this->hidden_splited_num = tiling_data->hidden_splited_num;
    this->hidden_splited_remain = tiling_data->hidden_splited_remain;
    // token_tiling
    this->tokens_core_length = tiling_data->tokens_core_length;
    this->tokens_core_remain = tiling_data->tokens_core_remain;
    this->tokens_splited_length = tiling_data->tokens_splited_length;
    this->tokens_splited_num = tiling_data->tokens_splited_num;
    this->tokens_splited_remain = tiling_data->tokens_splited_remain;

    // Remainder handling for tokens-per-core: the first tokens_core_remain
    // cores each take one extra output token, absorbed into the tail batch.
    if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) {
        this->tokens_core_length += 1;
        this->tokens_splited_remain += 1;
    }

    // Round the hidden-slice element count up to a multiple of 512 for UB sizing.
    int64_t hidden_splited_length_align512 = (this->hidden_splited_length + ALIGN_512 - 1) & ~(ALIGN_512 - 1);

    int64_t block_length = this->tokens_core_length * this->top_k;
    int64_t block_splited_length = this->tokens_splited_length * this->top_k;

    // Offset of this core's first entry in sorted_indices/probs
    // (counted in top_k-expanded rows).
    int64_t block_offset;
    if (this->tokens_core_remain > 0) {
        if (blockIdx < this->tokens_core_remain) {
            block_offset = block_length * blockIdx;
        } else {
            // Each earlier core that carried an extra token covers
            // block_length + top_k rows.
            block_offset = (block_length + this->top_k) * this->tokens_core_remain +
                           block_length * (blockIdx - this->tokens_core_remain);
        }
    } else {
        block_offset = block_length * blockIdx;
    }

    this->tokensGM.SetGlobalBuffer((__gm__ T1 *)permuted_tokens);
    this->indicesGM.SetGlobalBuffer((__gm__ T2 *)sorted_indices + block_offset, block_length);

    // Offset of this core's first output token in the unpermuted output (elements).
    int64_t out_block_offset;
    if (this->tokens_core_remain > 0) {
        if (blockIdx < this->tokens_core_remain) {
            out_block_offset = this->tokens_core_length * blockIdx * hidden_size;
        } else {
            out_block_offset = (this->tokens_core_length + 1) * this->tokens_core_remain +
                               this->tokens_core_length * (blockIdx - this->tokens_core_remain);
            out_block_offset *= this->hidden_size;
        }
    } else {
        out_block_offset = this->tokens_core_length * blockIdx * hidden_size;
    }

    this->outGM.SetGlobalBuffer((__gm__ T1 *)unpermuted_tokens + out_block_offset,
                                this->tokens_core_length * this->hidden_size);

    this->pipe.InitBuffer(tokens_inque, tiling_data->buffer_num, hidden_splited_length_align512 * sizeof(T1));
    this->pipe.InitBuffer(indices_inque, 1, block_splited_length * (sizeof(T2)));
    this->pipe.InitBuffer(outque, 1, hidden_splited_length_align512 * sizeof(T1));

    // Non-fp32 token types are accumulated via fp32 scratch buffers.
    if constexpr (!IsSameType<T1, float>::value) {
        this->pipe.InitBuffer(temp_buffer0, hidden_splited_length_align512 * sizeof(float) + 256);
        this->pipe.InitBuffer(temp_buffer1, hidden_splited_length_align512 * sizeof(float));
        this->token_tensor0 = this->temp_buffer0.template Get<float>();
        this->token_tensor1 = this->temp_buffer1.template Get<float>();
    }

    if constexpr (PROBS) {
        this->probsGM.SetGlobalBuffer((__gm__ T3 *)probs + block_offset, block_length);
        this->pipe.InitBuffer(probs_inque, 1, block_splited_length * (sizeof(T3)));
        // Non-fp32 probabilities are cast into an fp32 scratch buffer before use.
        if constexpr (!IsSameType<T3, float>::value) {
            this->pipe.InitBuffer(temp_buffer2, block_splited_length * sizeof(float));
            this->probs_tensor = this->temp_buffer2.template Get<float>();
        }
    }
};
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
|
||||
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Process()
|
||||
{
|
||||
|
||||
if (blockIdx >= blockNum) {
|
||||
return;
|
||||
}
|
||||
for (int64_t i = 0; i < this->tokens_splited_num; ++i) {
|
||||
CalMultiOutToken(i * this->tokens_splited_length, this->tokens_splited_length);
|
||||
}
|
||||
// 处理tokens_num不能均匀分核数的尾块
|
||||
if (this->tokens_splited_remain > 0) {
|
||||
CalMultiOutToken(this->tokens_splited_num * this->tokens_splited_length, this->tokens_splited_remain);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalMultiOutToken(const int64_t out_offset,
                                                                                    const int64_t out_tokens_number)
{
    // Stage the sorted indices for this batch (top_k entries per output token).
    this->indicesLocal = this->indices_inque.template AllocTensor<T2>();
    int64_t in_offset = out_offset * this->top_k;
    this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T2);
    DataCopyPad(this->indicesLocal, this->indicesGM[in_offset], this->copyParams, this->extParams2);
    this->indices_inque.template EnQue(this->indicesLocal);

    if constexpr (PROBS) {
        // Stage the routing probabilities; non-fp32 probs are cast to fp32.
        LocalTensor<T3> temp_probs_tensor = this->probs_inque.template AllocTensor<T3>();
        this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T3);
        DataCopyPad(temp_probs_tensor, this->probsGM[in_offset], this->copyParams, this->extParams3);
        this->probs_inque.template EnQue(temp_probs_tensor);
        temp_probs_tensor = this->probs_inque.template DeQue<T3>();
        if constexpr (!IsSameType<T3, float>::value) {
            Cast(this->probs_tensor, temp_probs_tensor, RoundMode::CAST_NONE, out_tokens_number * this->top_k);
            this->probs_inque.FreeTensor(temp_probs_tensor);
            PipeBarrier<PIPE_V>();
        } else {
            // fp32 probs are used in place; the tensor is freed after the batch below.
            this->probs_tensor = temp_probs_tensor;
        }
    }
    this->indicesLocal = this->indices_inque.template DeQue<T2>();

    // Accumulate each output token of this batch.
    for (int64_t out_token_idx = 0; out_token_idx < out_tokens_number; ++out_token_idx) {
        CalSingleOutToken(out_token_idx * this->top_k, out_offset + out_token_idx);
    }
    // Free Tensor
    this->indices_inque.FreeTensor(this->indicesLocal);
    if constexpr (PROBS && IsSameType<T3, float>::value) {
        this->probs_inque.FreeTensor(this->probs_tensor);
    }
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
|
||||
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalSingleOutToken(const int64_t start_token,
|
||||
const int64_t out_token_idx)
|
||||
{
|
||||
for (int64_t h_index = 0; h_index < this->hidden_splited_num; ++h_index) {
|
||||
CalPartOutToken(start_token, h_index, this->hidden_splited_length, out_token_idx);
|
||||
}
|
||||
// 一次不能完整容纳完整的hidden_size, 处理尾块
|
||||
if (this->hidden_splited_remain > 0) {
|
||||
CalPartOutToken(start_token, this->hidden_splited_num, this->hidden_splited_remain, out_token_idx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void
KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalPartOutToken(const int64_t start_token, const int64_t h_index,
                                                            const int64_t h_length, const int64_t out_token_index)
{
    // For fp32 tokens the accumulator is allocated from the output queue;
    // otherwise token_tensor0 was already bound to the fp32 scratch buffer in Init.
    if constexpr (IsSameType<T1, float>::value) {
        this->token_tensor0 = this->outque.template AllocTensor<T1>();
    }
    int64_t end_token = start_token + this->top_k;
    T2 cal_token_idx = this->indicesLocal.GetValue(start_token);

    // First contribution initializes the accumulator. Indices >= num_out_tokens
    // denote dropped tokens; in that case the accumulator starts at zero.
    if (cal_token_idx < this->num_out_tokens) {
        float probsValue = 0;
        if constexpr (PROBS) {
            probsValue = this->probs_tensor.GetValue(start_token);
        }

        CopyTokenIn(cal_token_idx, h_index, h_length);
        PipeBarrier<PIPE_V>();
        CalFirstToken(probsValue, h_length);
    } else {
        PipeBarrier<PIPE_V>();
        Duplicate(this->token_tensor0, static_cast<float>(0), h_length);
    }

    // Remaining top_k - 1 contributions are accumulated on top; dropped
    // token indices are skipped.
    for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) {
        cal_token_idx = this->indicesLocal.GetValue(token_index);
        if (cal_token_idx < this->num_out_tokens) {
            float probsValue = 0;
            if constexpr (PROBS) {
                probsValue = this->probs_tensor.GetValue(token_index);
            }

            CopyTokenIn(cal_token_idx, h_index, h_length);
            PipeBarrier<PIPE_V>();
            CalToken(probsValue, h_length);
        }
    }

    // Emit the accumulated slice.
    CopyOut(out_token_index, h_index, h_length);
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CopyTokenIn(const T2 in_token_index,
                                                                               const int64_t h_index,
                                                                               const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template AllocTensor<T1>();
    // Element offset of the requested slice within the permuted-token matrix.
    int64_t offset = in_token_index * this->hidden_size + h_index * this->hidden_splited_length;

    // Fast path when the slice byte length is 32 B aligned; otherwise use a
    // padded copy with an explicit byte length.
    if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) {
        DataCopy(tokensLocal, this->tokensGM[offset], h_length);
    } else {
        this->copyParams.blockLen = h_length * sizeof(T1);
        DataCopyPad(tokensLocal, this->tokensGM[offset], this->copyParams, this->extParams1);
    }

    this->tokens_inque.template EnQue(tokensLocal);
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalFirstToken(const float prob_value,
                                                                                 const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template DeQue<T1>();

    // Move the staged slice into the fp32 accumulator: a cast for non-fp32
    // token types, a 32 B-aligned UB-to-UB copy for fp32.
    if constexpr (!IsSameType<T1, float>::value) {
        Cast(this->token_tensor0, tokensLocal, RoundMode::CAST_NONE, h_length);
    } else {
        uint64_t byteAlign32 = (h_length * sizeof(float) + BLOCK_SIZE - 1) & ~(BLOCK_SIZE - 1);
        DataCopy(this->token_tensor0, tokensLocal, byteAlign32 / sizeof(float));
    }

    this->tokens_inque.FreeTensor(tokensLocal);

    // Weight the contribution by its routing probability.
    if constexpr (PROBS) {
        PipeBarrier<PIPE_V>();
        Muls(this->token_tensor0, this->token_tensor0, prob_value, h_length);
    }
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalToken(const float prob_value,
                                                                            const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template DeQue<T1>();

    if constexpr (!IsSameType<T1, float>::value) {
        // Cast into the fp32 scratch buffer, optionally scale, then add to
        // the accumulator.
        Cast(this->token_tensor1, tokensLocal, RoundMode::CAST_NONE, h_length);
        this->tokens_inque.FreeTensor(tokensLocal);
        if constexpr (PROBS) {
            PipeBarrier<PIPE_V>();
            Muls(this->token_tensor1, this->token_tensor1, prob_value, h_length);
        }
        PipeBarrier<PIPE_V>();
        Add(this->token_tensor0, this->token_tensor0, this->token_tensor1, h_length);
    } else {
        // fp32 slices are scaled in place and accumulated directly.
        if constexpr (PROBS) {
            Muls(tokensLocal, tokensLocal, prob_value, h_length);
            PipeBarrier<PIPE_V>();
        }
        Add(this->token_tensor0, this->token_tensor0, tokensLocal, h_length);
        this->tokens_inque.FreeTensor(tokensLocal);
    }
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CopyOut(const int64_t out_token_index,
                                                                           const int64_t h_index,
                                                                           const int64_t h_length)
{
    LocalTensor<T1> temp_out_tensors;
    // Cast the fp32 accumulator back to T1; for fp32 output the accumulator
    // itself already lives in the output queue (allocated in CalPartOutToken).
    if constexpr (!IsSameType<T1, float>::value) {
        temp_out_tensors = this->outque.template AllocTensor<T1>();
        PipeBarrier<PIPE_V>();
        Cast(temp_out_tensors, this->token_tensor0, RoundMode::CAST_RINT, h_length);
    } else {
        temp_out_tensors = this->token_tensor0;
    }

    this->outque.template EnQue<T1>(temp_out_tensors);
    temp_out_tensors = this->outque.template DeQue<T1>();

    // Element offset of the destination slice within the unpermuted output.
    int64_t offset = out_token_index * this->hidden_size + h_index * this->hidden_splited_length;
    // Aligned fast path vs padded copy, matching the input-side handling.
    if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) {
        DataCopy(this->outGM[offset], temp_out_tensors, h_length);
    } else {
        this->copyParams.blockLen = h_length * sizeof(T1);
        DataCopyPad(this->outGM[offset], temp_out_tensors, this->copyParams);
    }

    this->outque.FreeTensor(temp_out_tensors);
}
|
||||
#endif // MOE_TOKEN_UNPERMUTE
|
||||
@@ -0,0 +1,38 @@
|
||||
#ifndef MOE_TOKEN_UNPERMUTE_TILING
|
||||
#define MOE_TOKEN_UNPERMUTE_TILING
|
||||
|
||||
// Tiling parameters consumed by KernelMoeTokenUnpermute.
struct MoeTokenUnpermuteTilingData {
    int64_t hidden_size;            // elements per token row
    int64_t top_k;                  // contributions summed per output token
    int64_t num_out_tokens;         // valid permuted rows; indices >= this are dropped
    int64_t hidden_splited_length;  // hidden elements processed per UB slice
    int64_t hidden_splited_num;     // full hidden slices per row
    int64_t hidden_splited_remain;  // length of the hidden tail slice (0 if none)
    int64_t tokens_core_length;     // output tokens per core (before remainder)
    int64_t tokens_core_remain;     // number of cores that take one extra token
    int64_t tokens_splited_length;  // output tokens per UB batch
    int64_t tokens_splited_num;     // full batches per core
    int64_t tokens_splited_remain;  // size of the tail batch (0 if none)
    int64_t buffer_num;             // queue depth for the token input queue
};
|
||||
|
||||
/**
 * Computes tiling for the MoeTokenUnpermute kernel.
 *
 * @param m          number of permuted token rows (stored as num_out_tokens)
 * @param n          hidden size per token
 * @param topK       contributions summed per output token
 * @param tilingData output tiling structure to fill in
 * @param coreNum    number of vector cores the work is split across
 */
__forceinline__ [host, aicore] void
MoeTokenUnpermuteTiling(int32_t m, int32_t n, int32_t topK, MoeTokenUnpermuteTilingData &tilingData, uint32_t coreNum)
{
#define I64(x) static_cast<int64_t>(x)
    tilingData.hidden_size = I64(n);
    tilingData.top_k = I64(topK);
    tilingData.num_out_tokens = I64(m);
    // The whole hidden dimension is handled as a single slice.
    tilingData.hidden_splited_length = tilingData.hidden_size;
    tilingData.hidden_splited_num = 1;
    tilingData.hidden_splited_remain = 0;
    uint32_t outTokens = m / topK;
    tilingData.tokens_core_length = I64(outTokens / coreNum);
    tilingData.tokens_core_remain = I64(outTokens % coreNum);
    // Cap the per-batch token count at 600, and clamp it to at least 1 so the
    // divisions below cannot divide by zero when outTokens < coreNum
    // (tokens_core_length == 0). With a batch length of 1, both quotient and
    // remainder below are 0 in that case, so kernel behavior is unchanged.
    // (The explicit ternary also avoids min(int64_t, int) type mismatch.)
    int64_t batchLen = tilingData.tokens_core_length < I64(600) ? tilingData.tokens_core_length : I64(600);
    if (batchLen < 1) {
        batchLen = 1;
    }
    tilingData.tokens_splited_length = batchLen;
    tilingData.tokens_splited_num = tilingData.tokens_core_length / batchLen;
    tilingData.tokens_splited_remain = tilingData.tokens_core_length % batchLen;
    tilingData.buffer_num = 4;
#undef I64
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,207 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/epilogue/block/block_epilogue.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
// Per-token dequantization epilogue: each row of the fp16 GEMM result C is
// multiplied by its per-token fp32 scale and written back to GM as fp16/bf16 D.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequant<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequant<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(
        std::is_same_v<ElementC, half> && (std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
        "The element type template parameters of BlockEpilogue are wrong"
    );
    static_assert(
        std::is_same_v<LayoutC, layout::RowMajor> &&
        std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
        "The layout template parameters of BlockEpilogue are wrong"
    );

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;

    struct Params {
        // NOTE(review): none of these fields are read anywhere in this
        // epilogue yet — presumably reserved for later use; confirm.
        __gm__ int32_t *ptrTokenPerExpert{nullptr};
        int32_t EP;              // expert-parallel world size
        int32_t expertPerRank;   // experts hosted per rank

        CATLASS_DEVICE
        Params() {};

        CATLASS_DEVICE
        Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
    };

    // Carves the per-stage UB buffers out of the shared resource and
    // pre-sets the V->MTE2 / MTE3->V flags so the first iteration's waits pass.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        // UB allocation starts at byte 4096 — presumably the first 4 KiB is
        // reserved by the surrounding kernel; TODO confirm.
        size_t ubOffset = 4096;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        // Per-stage staging capacity in elements; operator() appears to assume
        // shapeC.column() <= 12000 — TODO confirm at call sites.
        constexpr int32_t blockN = 12000;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += blockN * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += blockN * sizeof(ElementD);

            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;

            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
            ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += blockN * sizeof(float);
        }
    }
    // Consumes the flags set in the constructor / last loop iteration so no
    // event is left pending after the epilogue finishes.
    CATLASS_DEVICE
    void Finalize()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }
    CATLASS_DEVICE
    ~BlockEpilogue()
    {

    }

    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }

    // Dequantizes shapeC.row() rows of C, one row per iteration, cycling
    // round-robin through the UB_STAGES staging buffers.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
        AscendC::GlobalTensor<ElementD> const &gmD
    )
    {
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();

        uint32_t tileLoops = blockM;

        for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx ++) {
            auto gmTileC = gmC[loopIdx * blockN];
            auto &ubC = ubCList[ubListId];
            auto &ubCFp32 = ubCFp32List[ubListId];
            // NOTE(review): ubMul is never used and ubMulList is never bound
            // to a UB buffer — candidate for removal.
            auto &ubMul = ubMulList[ubListId];
            auto &ubD = ubDList[ubListId];
            auto gmTileD = gmD[loopIdx * blockN];
            LayoutC layoutUbC{1, blockN};

            // Move this row of C from the GM workspace into UB.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);

            // Cast the row to fp32 in UB.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);

            // Fetch this row's per-token scale (entry loopIdx of gmPerTokenScale).
            ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);

            // Order the scalar read before the vector Muls below.
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Multiply the fp32 row by the per-token scale.
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
            AscendC::PipeBarrier<PIPE_V>();

            // Cast the scaled row back to fp16/bf16 and write it out.
            LayoutD layoutUbD{1, blockN};
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);

            AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);

            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);

            // Advance to the next staging buffer (round robin).
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }

private:
    Params params;

    // Per-stage staging buffers for C (fp16) and D (fp16/bf16).
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];

    // Hardware event ids that sequence MTE2 / V / MTE3 work per stage.
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];

    uint32_t ubListId{0};   // staging buffer used by the current iteration

    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];  // fp32 working copy per stage
    // NOTE(review): ubMulList is declared but never assigned a UB buffer;
    // its only use is an unused alias in operator().
    AscendC::LocalTensor<float> ubMulList[UB_STAGES];


    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
|
||||
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
@@ -0,0 +1,316 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
// float scale, dequant per expert
|
||||
template <
|
||||
uint32_t UB_STAGES_,
|
||||
class CType_,
|
||||
class LayoutPerTokenScale_,
|
||||
class DType_,
|
||||
class TileElemWiseMuls_,
|
||||
class TileCopy_
|
||||
>
|
||||
class BlockEpilogue <
|
||||
EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>,
|
||||
CType_,
|
||||
Gemm::GemmType<float, LayoutPerTokenScale_>,
|
||||
DType_,
|
||||
TileElemWiseMuls_,
|
||||
TileCopy_
|
||||
> {
|
||||
public:
|
||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
|
||||
// Data infos
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using ElementPerTokenScale = float;
|
||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||
using ElementD = typename DType_::Element;
|
||||
using LayoutD = typename DType_::Layout;
|
||||
|
||||
// Check data infos
|
||||
static_assert(
|
||||
std::is_same_v<ElementC, half> && (std::is_same_v<ElementD, float> || std::is_same_v<ElementD, int8_t>),
|
||||
"The element type template parameters of BlockEpilogue are wrong"
|
||||
);
|
||||
static_assert(
|
||||
std::is_same_v<LayoutC, layout::RowMajor> &&
|
||||
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
|
||||
"The layout template parameters of BlockEpilogue are wrong"
|
||||
);
|
||||
|
||||
// Tile copy
|
||||
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
|
||||
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
|
||||
using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm<ArchTag, Gemm::GemmType<ElementPerTokenScale, LayoutPerTokenScale>>;
|
||||
|
||||
struct Params {
|
||||
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
||||
LayoutPerTokenScale layoutPerTokenScale{};
|
||||
__gm__ ElementD *ptrD{nullptr};
|
||||
LayoutD layoutD{};
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params() {};
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params(__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
||||
__gm__ ElementD *ptrD_, LayoutD const &layoutD_
|
||||
) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_),
|
||||
ptrD(ptrD_), layoutD(layoutD_) {}
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
size_t ubOffset = 0;
|
||||
int32_t eventVMTE2 = 0;
|
||||
int32_t eventMTE2V = 0;
|
||||
int32_t eventMTE3V = 0;
|
||||
int32_t eventVMTE3 = 0;
|
||||
constexpr uint32_t blockN = 4096;
|
||||
constexpr uint32_t ChunkTileLen = blockN / 2;
|
||||
constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
||||
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||
ubOffset += blockN * sizeof(ElementC);
|
||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
||||
ubOffset += blockN * sizeof(ElementD);
|
||||
ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += blockN * sizeof(float);
|
||||
ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += ChunkTileLen * sizeof(float);
|
||||
ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += ChunkTileLen * sizeof(float);
|
||||
ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += HalfChunkTileLen * sizeof(float);
|
||||
ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<int32_t>();
|
||||
ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<half>();
|
||||
|
||||
eventUbCVMTE2List[i] = eventVMTE2++;
|
||||
eventUbCMTE2VList[i] = eventMTE2V++;
|
||||
eventUbDMTE3VList[i] = eventMTE3V++;
|
||||
eventUbDVMTE3List[i] = eventVMTE3++;
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
||||
}
|
||||
|
||||
ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void Finalize()
|
||||
{
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
||||
}
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
~BlockEpilogue()
|
||||
{
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void UpdateParams(Params const ¶ms_)
|
||||
{
|
||||
params = params_;
|
||||
}
|
||||
// 每个tile就是1*7168,每个block是一个expert的所有token=[group[i], 7168]
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
MatrixCoord const &shapeC,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale1,
|
||||
AscendC::GlobalTensor<ElementD> const &gmD,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale2,
|
||||
|
||||
uint32_t epilogueCoreNum = 40,
|
||||
Callback &&callback = Callback{}
|
||||
)
|
||||
{
|
||||
callback();
|
||||
uint32_t blockM = shapeC.row();
|
||||
uint32_t blockN = shapeC.column();
|
||||
|
||||
uint32_t tileLoops = blockM;
|
||||
uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();
|
||||
|
||||
uint32_t subblockNum = get_block_num() * 2;
|
||||
uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;
|
||||
|
||||
if (subblockIdx < moveDataCoreNum) {
|
||||
return;
|
||||
}
|
||||
uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;
|
||||
|
||||
uint32_t perCoreData = blockM / epilogueCoreNum;
|
||||
uint32_t remainderData = blockM % epilogueCoreNum;
|
||||
|
||||
uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
|
||||
uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);
|
||||
|
||||
uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1);
|
||||
|
||||
uint32_t ChunkTileLen = blockN / 2;
|
||||
uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
||||
|
||||
|
||||
for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {
|
||||
|
||||
auto gmTileC = gmC[loopIdx * blockN];
|
||||
|
||||
auto &ubC = ubCList[ubListId];
|
||||
auto &ubD = ubDList[ubListId];
|
||||
|
||||
auto &ubCFp32 = ubCFp32List[ubListId];
|
||||
auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];
|
||||
auto &ubAbs = ubCFp32ChunkNAbsList[ubListId];
|
||||
// auto &ubMax = ubCFp32ChunkNMaxList[ubListId];
|
||||
auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId];
|
||||
auto &ubOutputTmp = ubAbs;
|
||||
auto &sharedUbTmpBuffer = ubReduceMax;
|
||||
auto &ubQuantS32 = ubQuantS32List[ubListId];
|
||||
auto &ubQuantF16 = ubQuantF16List[ubListId];
|
||||
|
||||
auto gmTileD = gmD[loopIdx * ChunkTileLen];
|
||||
LayoutC layoutUbC{1, blockN};
|
||||
|
||||
// 把C从GM workspace搬到UB
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
|
||||
// 在UB上做把C cast成FP32
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
|
||||
// 获取pertoken scale值,gmPerTokenScale的第loopIdx行
|
||||
ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||
// pertoken scale值与FP32的C做Muls乘法
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
//swiglue计算过程
|
||||
AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
//TODO除的时候是否会对之后的数据有影响;
|
||||
AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
|
||||
|
||||
//quant过程,两种方式区别;
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
AscendC::ReduceMax<float>(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
|
||||
|
||||
//TODO两种计算方法的效率比较
|
||||
ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||
|
||||
auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx;
|
||||
ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||
AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::SetDeqScale(static_cast<half>(1.0));
|
||||
AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
|
||||
AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
|
||||
// AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);
|
||||
|
||||
LayoutD layoutUbD{1, ChunkTileLen};
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
||||
copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
|
||||
}
|
||||
|
||||
if(tasksForIdx > 0){
|
||||
LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx};
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
|
||||
|
||||
copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
|
||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||
|
||||
int32_t eventUbCVMTE2List[UB_STAGES];
|
||||
int32_t eventUbCMTE2VList[UB_STAGES];
|
||||
int32_t eventUbDMTE3VList[UB_STAGES];
|
||||
int32_t eventUbDVMTE3List[UB_STAGES];
|
||||
|
||||
uint32_t ubListId{0};
|
||||
|
||||
AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubCFp32ChunkNList[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubCFp32ChunkNAbsList[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubCFp32ChunkNMaxList[UB_STAGES];
|
||||
AscendC::LocalTensor<int32_t> ubQuantS32List[UB_STAGES];
|
||||
AscendC::LocalTensor<half> ubQuantF16List[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubPerTokenScaleOutput;
|
||||
|
||||
|
||||
CopyGmToUbC copyGmToUbC;
|
||||
CopyUbToGmD copyUbToGmD;
|
||||
CopyUbToGmDequantScale copyUbToGmDequantScale;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
|
||||
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
@@ -0,0 +1,502 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/coord.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
#include "catlass/gemm/helper.hpp"
|
||||
#include "dispatch_policy_custom.hpp"
|
||||
|
||||
|
||||
namespace Catlass::Gemm::Block {
|
||||
|
||||
template<AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||
{
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
template <
|
||||
uint32_t PRELOAD_STAGES_,
|
||||
uint32_t L1_STAGES_,
|
||||
uint32_t L0A_STAGES_,
|
||||
uint32_t L0B_STAGES_,
|
||||
uint32_t L0C_STAGES_,
|
||||
bool ENABLE_UNIT_FLAG_,
|
||||
bool ENABLE_SHUFFLE_K_,
|
||||
class L1TileShape_,
|
||||
class L0TileShape_,
|
||||
class AType_,
|
||||
class BType_,
|
||||
class CType_,
|
||||
class BiasType_,
|
||||
class TileCopy_,
|
||||
class TileMmad_
|
||||
>
|
||||
struct BlockMmad <
|
||||
MmadAtlasA2PreloadAsyncFixpipe<
|
||||
PRELOAD_STAGES_,
|
||||
L1_STAGES_,
|
||||
L0A_STAGES_,
|
||||
L0B_STAGES_,
|
||||
L0C_STAGES_,
|
||||
ENABLE_UNIT_FLAG_,
|
||||
ENABLE_SHUFFLE_K_
|
||||
>,
|
||||
L1TileShape_,
|
||||
L0TileShape_,
|
||||
AType_,
|
||||
BType_,
|
||||
CType_,
|
||||
BiasType_,
|
||||
TileCopy_,
|
||||
TileMmad_
|
||||
> {
|
||||
public:
|
||||
// Type Aliases
|
||||
using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe<
|
||||
PRELOAD_STAGES_,
|
||||
L1_STAGES_,
|
||||
L0A_STAGES_,
|
||||
L0B_STAGES_,
|
||||
L0C_STAGES_,
|
||||
ENABLE_UNIT_FLAG_,
|
||||
ENABLE_SHUFFLE_K_
|
||||
>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using L1TileShape = L1TileShape_;
|
||||
using L0TileShape = L0TileShape_;
|
||||
using ElementA = typename AType_::Element;
|
||||
using LayoutA = typename AType_::Layout;
|
||||
using ElementB = typename BType_::Element;
|
||||
using LayoutB = typename BType_::Layout;
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using TileMmad = TileMmad_;
|
||||
using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
|
||||
using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
|
||||
using CopyGmToL1S = Gemm::Tile::CopyGmToL1<ArchTag, Gemm::GemmType<uint64_t, layout::VectorLayout>>;
|
||||
using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
|
||||
using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
|
||||
|
||||
using ElementAccumulator =
|
||||
typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
|
||||
using CopyL0CToGm = typename std::conditional<
|
||||
std::is_same_v<ElementA, int8_t>,
|
||||
Gemm::Tile::CopyL0CToGm<ArchTag, ElementAccumulator, CType_, Gemm::Tile::ScaleGranularity::PER_CHANNEL>,
|
||||
typename TileCopy_::CopyL0CToGm
|
||||
>::type;
|
||||
using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
|
||||
using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
|
||||
using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
|
||||
using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
|
||||
using LayoutCInL0 = layout::zN;
|
||||
|
||||
using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
|
||||
using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;
|
||||
|
||||
static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
|
||||
static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES;
|
||||
static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
|
||||
static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
|
||||
static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;
|
||||
|
||||
static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
|
||||
static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;
|
||||
|
||||
// L1 tile size
|
||||
static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
|
||||
static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
|
||||
static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t);
|
||||
// L0 tile size
|
||||
static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
|
||||
static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
|
||||
static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);
|
||||
|
||||
// Check LayoutC
|
||||
static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");
|
||||
|
||||
// Check L1TileShape
|
||||
static_assert(
|
||||
(std::is_same_v<ElementA, int8_t>
|
||||
? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE
|
||||
: (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE),
|
||||
"L1TileShape exceeding the L1 space for the given data type"
|
||||
);
|
||||
|
||||
// Check L0TileShape
|
||||
static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!");
|
||||
static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!");
|
||||
static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!");
|
||||
|
||||
static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
|
||||
"The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet");
|
||||
|
||||
static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(
|
||||
L1TileShape::M, L1TileShape::K);
|
||||
static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(
|
||||
L1TileShape::K, L1TileShape::N);
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
|
||||
{
|
||||
syncGroupIdx = 0;
|
||||
InitL1(resource, l1BufAddrStart);
|
||||
InitL0A(resource);
|
||||
InitL0B(resource);
|
||||
InitL0C(resource);
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
~BlockMmad()
|
||||
{
|
||||
SynchronizeBlock();
|
||||
for (uint32_t i = 0; i < L1_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
|
||||
}
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(
|
||||
AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
|
||||
AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
|
||||
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
|
||||
AscendC::GlobalTensor<uint64_t> const &gmBlockS, layout::VectorLayout const &layoutScale,
|
||||
GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0
|
||||
)
|
||||
{
|
||||
uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
|
||||
|
||||
uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
|
||||
uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());
|
||||
|
||||
uint32_t startTileIdx = 0;
|
||||
if constexpr (ENABLE_SHUFFLE_K) {
|
||||
startTileIdx = AscendC::GetBlockIdx() % kTileCount;
|
||||
}
|
||||
|
||||
for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
|
||||
uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ?
|
||||
(startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount);
|
||||
|
||||
uint32_t kActual = (kTileIdx < kTileCount - 1) ?
|
||||
L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);
|
||||
|
||||
// Emission load instruction from GM to L1
|
||||
MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
|
||||
MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
|
||||
auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
|
||||
auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
|
||||
// Load first matrix A tile from GM to L1
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
|
||||
auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
|
||||
copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
|
||||
// Load first matrix B tile from GM to L1
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
|
||||
auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
|
||||
copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);
|
||||
|
||||
// If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile
|
||||
if (preloadCount == PRELOAD_STAGES) {
|
||||
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
|
||||
}
|
||||
|
||||
// Store the current load status
|
||||
uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ?
|
||||
(l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
|
||||
auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
|
||||
l1TileMmadParams.l1ListId = l1ListId;
|
||||
l1TileMmadParams.mRound = mRound;
|
||||
l1TileMmadParams.nRound = nRound;
|
||||
l1TileMmadParams.kActual = kActual;
|
||||
l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
|
||||
l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
|
||||
l1TileMmadParams.flag = flag;
|
||||
if (kLoopIdx == kTileCount - 1) {
|
||||
l1TileMmadParams.gmBlockC = gmBlockC;
|
||||
l1TileMmadParams.gmBlockS = gmBlockS;
|
||||
l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
|
||||
l1TileMmadParams.layoutScale = layoutScale;
|
||||
l1TileMmadParams.syncLoopIdx = syncLoopIdx;
|
||||
}
|
||||
|
||||
if (preloadCount < PRELOAD_STAGES) {
|
||||
++preloadCount;
|
||||
} else {
|
||||
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
|
||||
}
|
||||
l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void SynchronizeBlock()
|
||||
{
|
||||
while (preloadCount > 0) {
|
||||
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
|
||||
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
|
||||
--preloadCount;
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Finalize(int32_t target, int32_t flag = 0)
|
||||
{
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / 8 + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
}
|
||||
private:
|
||||
struct L1TileMmadParams {
|
||||
uint32_t l1ListId;
|
||||
uint32_t mRound;
|
||||
uint32_t nRound;
|
||||
uint32_t kActual;
|
||||
bool isKLoopFirst;
|
||||
bool isKLoopLast;
|
||||
AscendC::GlobalTensor<ElementC> gmBlockC;
|
||||
AscendC::GlobalTensor<uint64_t> gmBlockS;
|
||||
LayoutC layoutCInGm;
|
||||
layout::VectorLayout layoutScale;
|
||||
int32_t syncLoopIdx;
|
||||
int32_t flag;
|
||||
|
||||
CATLASS_DEVICE
|
||||
L1TileMmadParams() = default;
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
|
||||
{
|
||||
uint32_t l1AOffset = l1BufAddrStart;
|
||||
uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES;
|
||||
|
||||
for (uint32_t i = 0; i < L1_STAGES; ++i) {
|
||||
l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
|
||||
l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
|
||||
l1AEventList[i] = i;
|
||||
l1BEventList[i] = i + L1_STAGES;
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0A(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
|
||||
l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
|
||||
l0AEventList[i] = i;
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0B(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
|
||||
l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
|
||||
l0BEventList[i] = i + L0A_STAGES;
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0C(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
|
||||
l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
|
||||
l0CEventList[i] = i;
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void L1TileMmad(L1TileMmadParams const ¶ms)
|
||||
{
|
||||
uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
|
||||
uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
|
||||
uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
|
||||
auto &l1ATensor = l1ATensorList[params.l1ListId];
|
||||
auto &l1BTensor = l1BTensorList[params.l1ListId];
|
||||
|
||||
auto &l0CTensor = l0CTensorList[l0CListId];
|
||||
LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));
|
||||
|
||||
if constexpr (!ENABLE_UNIT_FLAG) {
|
||||
if (params.isKLoopFirst) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
|
||||
uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ?
|
||||
L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);
|
||||
|
||||
for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
|
||||
uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ?
|
||||
L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);
|
||||
|
||||
auto &l0ATile = l0ATensorList[l0AListId];
|
||||
auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
|
||||
auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
|
||||
auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
|
||||
if ((mPartIdx == 0) && (kPartIdx == 0)) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1ListId]);
|
||||
}
|
||||
copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
|
||||
if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1ListId]);
|
||||
}
|
||||
|
||||
for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
|
||||
uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ?
|
||||
L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);
|
||||
|
||||
auto &l0BTile = l0BTensorList[l0BListId];
|
||||
auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
|
||||
auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
|
||||
auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
|
||||
if ((kPartIdx == 0) && (nPartIdx == 0)) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1ListId]);
|
||||
}
|
||||
copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
|
||||
if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1ListId]);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
|
||||
|
||||
auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
|
||||
auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
|
||||
// If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
|
||||
bool initC = (params.isKLoopFirst && (kPartIdx == 0));
|
||||
// If the unit flag is enabled, the unit flag is set according to the calculation progress
|
||||
uint8_t unitFlag = 0b00;
|
||||
if constexpr (ENABLE_UNIT_FLAG) {
|
||||
if (params.isKLoopLast &&
|
||||
(mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
|
||||
unitFlag = 0b11;
|
||||
} else {
|
||||
unitFlag = 0b10;
|
||||
}
|
||||
}
|
||||
tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
|
||||
l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
|
||||
l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.isKLoopLast) {
|
||||
auto layoutCInGm = params.layoutCInGm;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
auto layoutScale = params.layoutScale;
|
||||
auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1)));
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_FIX>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_FIX>(0);
|
||||
}
|
||||
if constexpr (!ENABLE_UNIT_FLAG) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0);
|
||||
} else {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
|
||||
} else {
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11);
|
||||
} else {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
|
||||
}
|
||||
}
|
||||
l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
Finalize(params.syncLoopIdx, params.flag);
|
||||
}
|
||||
}
|
||||
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<uint64_t> l1STensor;
|
||||
int32_t syncGroupIdx;
|
||||
int32_t l1AEventList[L1_STAGES];
|
||||
int32_t l1BEventList[L1_STAGES];
|
||||
uint32_t l1ListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
|
||||
int32_t l0AEventList[L0A_STAGES];
|
||||
uint32_t l0AListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
|
||||
int32_t l0BEventList[L0B_STAGES];
|
||||
uint32_t l0BListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
|
||||
int32_t l0CEventList[L0C_STAGES_];
|
||||
uint32_t l0CListId{0};
|
||||
|
||||
L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
|
||||
uint32_t l1TileMmadParamsId{0};
|
||||
uint32_t preloadCount{0};
|
||||
|
||||
TileMmad tileMmad;
|
||||
CopyGmToL1A copyGmToL1A;
|
||||
CopyGmToL1B copyGmToL1B;
|
||||
CopyGmToL1S copyGmToL1S;
|
||||
CopyL1ToL0A copyL1ToL0A;
|
||||
CopyL1ToL0B copyL1ToL0B;
|
||||
CopyL0CToGm copyL0CToGm;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Block
|
||||
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
6
csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp
Normal file
6
csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp
Normal file
@@ -0,0 +1,6 @@
|
||||
|
||||
#ifndef CONST_ARGS_HPP
|
||||
#define CONST_ARGS_HPP
|
||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||
#endif
|
||||
@@ -0,0 +1,40 @@
|
||||
#ifndef COPY_GM_TO_L1_CUSTOM_HPP
|
||||
#define COPY_GM_TO_L1_CUSTOM_HPP
|
||||
|
||||
namespace Catlass::Gemm::Tile {
|
||||
/// Partial specialization for nZ in and nZ out.
|
||||
template <
|
||||
class ArchTag,
|
||||
class Element
|
||||
>
|
||||
struct CopyGmToL1<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
|
||||
using LayoutDst = layout::VectorLayout;
|
||||
using LayoutSrc = layout::VectorLayout;
|
||||
|
||||
static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); // int64, 32/8=4
|
||||
|
||||
// Mehtods
|
||||
|
||||
CATLASS_DEVICE
|
||||
CopyGmToL1() {};
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(
|
||||
AscendC::LocalTensor<Element> const &dstTensor,
|
||||
AscendC::GlobalTensor<Element> const &srcTensor,
|
||||
LayoutDst const &layoutDst, LayoutSrc const &layoutSrc)
|
||||
{
|
||||
uint32_t blockCount = 1;
|
||||
uint32_t blockLen = CeilDiv<ELE_NUM_PER_C0>(layoutSrc.shape(0));
|
||||
|
||||
AscendC::DataCopyParams repeatParams;
|
||||
|
||||
repeatParams.blockCount = blockCount;
|
||||
repeatParams.blockLen = blockLen;
|
||||
repeatParams.srcStride = 0;
|
||||
repeatParams.dstStride = 0;
|
||||
AscendC::DataCopy(dstTensor, srcTensor, repeatParams);
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif // COPY_GM_TO_L1_CUSTOM_HPP
|
||||
@@ -0,0 +1,47 @@
|
||||
#ifndef COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
#define COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
|
||||
namespace Catlass::Gemm::Tile {
|
||||
template <
|
||||
class ElementAccumulator_,
|
||||
class ElementDst_,
|
||||
bool ReluEnable_
|
||||
>
|
||||
struct CopyL0CToGm<Catlass::Arch::AtlasA2,
|
||||
ElementAccumulator_,
|
||||
Gemm::GemmType<ElementDst_, layout::RowMajor>,
|
||||
ScaleGranularity::PER_CHANNEL,
|
||||
ReluEnable_>
|
||||
{
|
||||
using ArchTag = Catlass::Arch::AtlasA2;
|
||||
using ElementDst = ElementDst_;
|
||||
using ElementSrc = ElementAccumulator_;
|
||||
using LayoutSrc = Catlass::layout::zN;
|
||||
using LayoutDst = Catlass::layout::RowMajor;
|
||||
static constexpr auto quantPre = CopyL0CToGmQuantMode<ArchTag, ElementSrc, ElementDst,
|
||||
ScaleGranularity::PER_CHANNEL>::VALUE;
|
||||
static constexpr auto reluEn = ReluEnable_;
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(AscendC::GlobalTensor<ElementDst> const &dst, AscendC::LocalTensor<ElementSrc> const &src, AscendC::LocalTensor<uint64_t> cbufWorkspace,
|
||||
LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::FixpipeParamsV220 intriParams;
|
||||
|
||||
// Fixpipe layout information
|
||||
intriParams.nSize = dstLayout.shape(1);
|
||||
intriParams.mSize = dstLayout.shape(0);
|
||||
intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0);
|
||||
intriParams.dstStride = dstLayout.stride(0);
|
||||
|
||||
// Fixpipe auxiliary arguments
|
||||
intriParams.quantPre = quantPre;
|
||||
intriParams.reluEn = reluEn;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
|
||||
// Call AscendC Fixpipe
|
||||
AscendC::Fixpipe<ElementDst, ElementSrc, AscendC::CFG_ROW_MAJOR>(dst, src, cbufWorkspace, intriParams);
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif // COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
@@ -0,0 +1,47 @@
|
||||
#ifndef DISPATH_POLICY_CUSTOM_HPP
#define DISPATH_POLICY_CUSTOM_HPP
// NOTE(review): guard macro is misspelled ("DISPATH"); harmless since guards
// only need to be unique, but consider renaming to DISPATCH_POLICY_CUSTOM_HPP.

// Custom dispatch-policy tag types for the dispatch_ffn_combine kernel.
// These structs carry compile-time configuration only (no runtime state);
// the GEMM/epilogue templates specialize on them to select a pipeline.
namespace Catlass::Gemm {
// AtlasA2 MMA policy tag for the Fixpipe-quantized output path.
// STAGES fixes the pipeline stage count; the flags toggle unit-flag
// synchronization and K-dimension shuffling.
template <bool ENABLE_UNIT_FLAG_ = false, bool ENABLE_SHUFFLE_K_ = false>
struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 {
    static constexpr uint32_t STAGES = 2;
    static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;
    static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;
};

// Async-preload MMA policy with Fixpipe output. Configuration is identical
// to MmadAtlasA2PreloadAsync; this exists as a distinct tag so the kernel
// can dispatch to a different implementation.
template <uint32_t PRELOAD_STAGES_, uint32_t L1_STAGES_, uint32_t L0A_STAGES_, uint32_t L0B_STAGES_,
          uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncFixpipe :
    public MmadAtlasA2PreloadAsync<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    > {
};
}

// Epilogue policy tags: each selects one post-GEMM processing variant.
// UB_STAGES sets the unified-buffer staging depth.
namespace Catlass::Epilogue {

// Tag for the epilogue variant without dequantization.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2UnQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};

// Tag for the per-token dequantize + re-quantize epilogue variant.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};

// Tag for the per-token dequantize + SwiGLU + re-quantize epilogue variant.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
}
#endif // DISPATH_POLICY_CUSTOM_HPP
|
||||
178
csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp
Normal file
178
csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp
Normal file
@@ -0,0 +1,178 @@
|
||||
#ifndef SYNC_UTIL_HPP
|
||||
#define SYNC_UTIL_HPP
|
||||
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "const_args.hpp"
|
||||
|
||||
#include "moe_distribute_base.h"
|
||||
|
||||
#ifndef HCCL_COMM
|
||||
#include "shmem_api.h"
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
|
||||
*((__gm__ T *)addr) = val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
|
||||
return *((__gm__ T *)cache);
|
||||
}
|
||||
|
||||
// Clean-and-invalidate the single cache line containing `addr` so that the
// next GM load observes stores made by other cores/ranks (and our own store
// to that line is pushed out). Used by the cross-rank barrier signals.
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
    using namespace AscendC;
    GlobalTensor<uint8_t> global;
    global.SetGlobalBuffer(addr);

    // Important: add hint to avoid dcci being optimized by compiler
    // (the empty volatile asm acts as a compiler-level ordering barrier).
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<uint8_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(global);
    __asm__ __volatile__("");
}
|
||||
|
||||
// Spin until the barrier signal at `sig_addr` reaches `cmp_val`, or
// `cmp_val + 1` (the peer may already have entered the next barrier round).
// Each iteration invalidates the cache line first so remote stores become
// visible. Returns the observed signal value.
//
// Fixes vs. the original: the signal is loaded exactly once per iteration and
// that cached value is returned — previously *sig_addr was re-read after the
// comparison, so the returned value could differ from the value that matched.
// The unreachable trailing `return -1` is also removed.
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    do {
        gm_dcci((__gm__ uint8_t *)sig_addr);

        const int32_t observed = gm_load(sig_addr);
        if (observed == cmp_val || observed == cmp_val + 1) {
            return observed;
        }
    } while (true);
}
|
||||
|
||||
|
||||
// Maximum number of ranks whose window base pointers are cached.
constexpr int32_t MAX_RANK_SIZE = 32;

// Device-side view of the symmetric communication windows ("peer memory")
// shared across ranks. Two compilation modes:
//   - HCCL_COMM defined: window pointers are resolved once, at construction,
//     from the HCCL AICPU op context.
//   - otherwise: addresses are resolved through the shmem API on each call.
// The call operators translate (offset, rank) into a GM pointer inside the
// corresponding rank's window.
class HcclShmem {
public:
#ifdef HCCL_COMM // HCCL mode requires an initialized HCCL context
    __gm__ HcclOpResParamCustom *WinContext_{nullptr};
    Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
    GM_ADDR m_ptrArray[MAX_RANK_SIZE];  // per-rank window base pointers
    size_t m_segmentSize;               // window size in bytes (per rank)
    int32_t m_rank;                     // this rank's id within the group
    int32_t m_rankSize;                 // number of ranks in the group

    FORCE_INLINE_AICORE
    HcclShmem(){
        auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
        WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;

        m_rank = WinContext_->localUsrRankId;
        m_rankSize = WinContext_->rankSize;
        m_segmentSize = WinContext_->winSize;

        // Cache every rank's window base: ours from the local context, remote
        // ones from the per-rank relation resources.
        for (int i = 0; i < m_rankSize; i++) {
            m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
                ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
        }

    }

    // Window size in bytes.
    FORCE_INLINE_AICORE
    size_t SegmentSize() const {
        return m_segmentSize;
    }

    // Number of ranks in the communicator.
    FORCE_INLINE_AICORE
    int32_t RankSize() const {
        return m_rankSize;
    }
#endif

    FORCE_INLINE_AICORE
    GM_ADDR operator() () const { // no argument: base address of the local peer-mem window
#ifdef HCCL_COMM
        return m_ptrArray[m_rank];
#else
        return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
#endif
    }

    FORCE_INLINE_AICORE
    GM_ADDR operator() (int32_t index) const { // index argument: base address of rank `index`'s peer-mem window
#ifdef HCCL_COMM
        return m_ptrArray[index];
#else
        return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
#endif
    }

    // Translate a byte offset within the window into a pointer on `rankId`.
    // HCCL mode returns nullptr for out-of-range offset or rank.
    FORCE_INLINE_AICORE
    GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
        if (offset < 0 || offset >= m_segmentSize) {
            return nullptr;
        }
        if (rankId < 0 || rankId >= m_rankSize) {
            return nullptr;
        }
        return m_ptrArray[rankId] + offset;
#else
        return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
#endif
    }

    // FORCE_INLINE_AICORE
    // GM_ADDR operator () (GM_ADDR ptr, int32_t index) const { // same usage as shmem_ptr
    // #ifdef HCCL_COMM
    //     size_t offset = ptr - m_ptrArray[m_rank];
    //     if (offset < 0 || offset >= m_segmentSize) {
    //         return nullptr;
    //     }
    //     if (index < 0 || index >= m_rankSize) {
    //         return nullptr;
    //     }
    //     return m_ptrArray[index] + offset;
    // #else
    //     return shmem_ptr(ptr, index);
    // #endif
    // }

    FORCE_INLINE_AICORE
    ~HcclShmem() {
    }

    // Cross-rank barrier using a flag area located MB_SIZE bytes before the
    // end of the window. In int32 units from flag_offset:
    //   [peer * 16] : arrival counters, written remotely by each peer
    //   [+ 2048]    : local monotonically increasing round counter
    // Each vector core takes a stripe of peers: it writes this round's count
    // into the peer's slot for this rank, flushes it, then spins on that
    // peer's arrival. Ends with an on-device SyncAll before committing the
    // new round number.
    // NOTE(review): the 16-int32 stride presumably keeps each rank's flag on
    // its own cache line for gm_dcci granularity — confirm.
    FORCE_INLINE_AICORE
    void CrossRankSync() {
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        __gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset;
        __gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048;
        int count = gm_load(sync_base) + 1;  // target value for this round
        int vec_id = AscendC::GetBlockIdx();
        int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation();
        for(int i = vec_id; i < m_rankSize; i += vec_size) {
            __gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16;
            gm_store(sync_remote, count);
            gm_dcci((__gm__ uint8_t*)sync_remote);
            auto sync_check = sync_counter + i * 16;
            gm_signal_wait_until_eq_for_barrier(sync_check, count);
        }

        AscendC::SyncAll<true>();
        gm_store(sync_base, count);
    }

    // Address of the round counter used by CrossRankSync.
    FORCE_INLINE_AICORE
    __gm__ int32_t* SyncBaseAddr() {
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
    }
};
|
||||
|
||||
|
||||
#endif
|
||||
20
csrc/dispatch_ffn_combine/op_kernel/utils/layout3d.hpp
Normal file
20
csrc/dispatch_ffn_combine/op_kernel/utils/layout3d.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef LAYOUT_3D_HPP
|
||||
#define LAYOUT_3D_HPP
|
||||
#include "kernel_operator.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
class Layout3D {
|
||||
int64_t strides[2];
|
||||
public:
|
||||
CATLASS_DEVICE
|
||||
Layout3D() {}
|
||||
CATLASS_DEVICE
|
||||
Layout3D(int64_t stride0, int64_t stride1) {
|
||||
strides[0] = stride0;
|
||||
strides[1] = stride1;
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) {
|
||||
return dim0 * strides[0] + dim1 * strides[1] + dim2;
|
||||
}
|
||||
};
|
||||
#endif // LAYOUT_3D_HPP
|
||||
25
csrc/dispatch_ffn_combine/op_kernel/utils/select_helper.hpp
Normal file
25
csrc/dispatch_ffn_combine/op_kernel/utils/select_helper.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#ifndef SELECT_HELPER_HPP
|
||||
#define SELECT_HELPER_HPP
|
||||
|
||||
#include "catlass/layout/layout.hpp"
|
||||
using namespace AscendC;
|
||||
using namespace Catlass;
|
||||
|
||||
// Constructs a B-matrix layout for a (k, n) tile. Generic case: layouts that
// are directly constructible from the two extents.
template <typename Layout, typename ElementType, typename = void>
struct LayoutBInitializer {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout{k, n};
    }
};

// Specialization for layout::zN, which cannot be brace-constructed from
// extents and must be built through its element-typed factory (the physical
// layout depends on the element size).
template <typename Layout, typename ElementType>
struct LayoutBInitializer<Layout, ElementType,
    std::enable_if_t<std::is_same_v<Layout, layout::zN>>
> {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout::template MakeLayout<ElementType>(k, n);
    }
};
|
||||
#endif // SELECT_HELPER_HPP
|
||||
1
csrc/third_party/catlass
vendored
Submodule
1
csrc/third_party/catlass
vendored
Submodule
Submodule csrc/third_party/catlass added at 716fd7baa7
@@ -37,59 +37,60 @@
|
||||
namespace vllm_ascend {
|
||||
const int64_t INT4_NUMS_IN_INT32 = 8;
|
||||
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping, aclrtStream stream) {
|
||||
torch::Device src_device = src.device();
|
||||
torch::Device dst_device = dst.device();
|
||||
aclrtMemcpyKind memcpy_type;
|
||||
const torch::Tensor& block_mapping, aclrtStream stream)
|
||||
{
|
||||
torch::Device src_device = src.device();
|
||||
torch::Device dst_device = dst.device();
|
||||
aclrtMemcpyKind memcpy_type;
|
||||
|
||||
if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
|
||||
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same npu");
|
||||
memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
|
||||
} else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
|
||||
memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
} else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
|
||||
memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
|
||||
}
|
||||
if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
|
||||
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same npu");
|
||||
memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
|
||||
} else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
|
||||
memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
} else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
|
||||
memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
|
||||
}
|
||||
|
||||
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
|
||||
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
|
||||
|
||||
char* src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
char* src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
|
||||
|
||||
const int64_t num_blocks = block_mapping.size(0);
|
||||
const int64_t max_src_block = src.size(0);
|
||||
const int64_t max_dst_block = dst.size(0);
|
||||
for (size_t i = 0; i < num_blocks; i++) {
|
||||
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
|
||||
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
||||
TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block,
|
||||
"src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
|
||||
TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block,
|
||||
"dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");
|
||||
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
|
||||
|
||||
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||
const int64_t num_blocks = block_mapping.size(0);
|
||||
const int64_t max_src_block = src.size(0);
|
||||
const int64_t max_dst_block = dst.size(0);
|
||||
for (size_t i = 0; i < num_blocks; i++) {
|
||||
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
|
||||
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
|
||||
TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block,
|
||||
"src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
|
||||
TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block,
|
||||
"dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");
|
||||
|
||||
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||
|
||||
aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
|
||||
src_ptr + src_offset, block_size_in_bytes,
|
||||
memcpy_type, stream);
|
||||
}
|
||||
aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
|
||||
src_ptr + src_offset, block_size_in_bytes,
|
||||
memcpy_type, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
|
||||
{
|
||||
|
||||
const c10_npu::OptionalNPUGuard npuGuard(
|
||||
const c10_npu::OptionalNPUGuard npuGuard(
|
||||
(!x.device().is_cpu()) ? x.device() : y.device()
|
||||
);
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
swap_blocks_impl(x, y, z, stream);
|
||||
return;
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
swap_blocks_impl(x, y, z, stream);
|
||||
return;
|
||||
}
|
||||
|
||||
AscendType get_dtype_from_torch(at::ScalarType scalarType)
|
||||
@@ -617,7 +618,33 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
});
|
||||
cmd.Run();
|
||||
return;
|
||||
}
|
||||
|
||||
// Launch the fused MoE dispatch + W8A8 grouped FFN + combine kernel.
//
// x               : input hidden states.
// weight1/weight2 : int8 expert weights (up/gate and down projections).
// expert_idx      : per-token top-k expert ids.
// scale1/scale2   : int64 bit-packed per-channel dequant scales.
// probs           : per-token top-k routing weights (float32).
// group           : HCCL communicator name for the EP group.
// max_output_size : kernel-side output capacity bound.
// out             : preallocated output tensor, written in place and returned.
at::Tensor& dispatch_ffn_combine(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out
) {
    // The aclnn launcher consumes the group name as a C string, but a
    // string_view is not guaranteed to be NUL-terminated (the original
    // const_cast of group.data() relied on that). Copy into an owning,
    // NUL-terminated buffer first.
    std::string group_ep(group.data(), group.size());
    char *group_ep_ptr = group_ep.data();
    EXEC_NPU_CMD(aclnnDispatchFFNCombine,
                 x,
                 weight1,
                 weight2,
                 expert_idx,
                 scale1,
                 scale2,
                 probs,
                 group_ep_ptr,
                 max_output_size,
                 out);
    return out;
}
|
||||
|
||||
at::Tensor npu_lightning_indexer(
|
||||
@@ -810,4 +837,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
|
||||
|
||||
ops.def(
|
||||
"dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx,"
|
||||
" Tensor scale1, Tensor scale2, Tensor probs, str group,"
|
||||
" int max_output_size, Tensor! out) -> Tensor"
|
||||
);
|
||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||
}
|
||||
|
||||
@@ -156,7 +156,21 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
c10::optional<c10::string_view> quant_mode)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Meta-backend (shape inference / FakeTensor) implementation of
// dispatch_ffn_combine. The output tensor is preallocated by the caller with
// its final shape and dtype, so there is nothing to compute: return it as-is.
at::Tensor& dispatch_ffn_combine_meta(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out
) {
    return out;
}
|
||||
|
||||
at::Tensor npu_lightning_indexer_meta(
|
||||
@@ -244,5 +258,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta);
|
||||
// Sparse flash attention
|
||||
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
|
||||
// MoE dispatch-ffn-combine
|
||||
ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta);
|
||||
}
|
||||
}
|
||||
|
||||
168
tests/e2e/nightly/ops/test_dispatch_ffn_combine.py
Normal file
168
tests/e2e/nightly/ops/test_dispatch_ffn_combine.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
class TestDisptachFFNCombine:
    """Per-process harness for the dispatch_ffn_combine kernel e2e test.

    Each spawned process owns one instance: it binds to its NPU, joins an
    HCCL process group, then launches the fused kernel once with tiny toy
    shapes.
    """

    def __init__(self, rank, world_size, port):
        self.rank = rank
        self.world_size = world_size
        self.master_ip = "127.0.0.1"
        self.port = port

    def get_hcomm(self, comm_group):
        """Resolve the HCCL communicator name of `comm_group` for this rank."""
        if torch.__version__ > "2.0.1":
            npu_backend = comm_group._get_backend(torch.device("npu"))
            return npu_backend.get_hccl_comm_name(self.rank)
        return comm_group.get_hccl_comm_name(self.rank)

    def setup_ep_tp(
        self,
        rank,
        tp_size,
        ep_size,
        backend_type,
        ep_ranks_list=None,
        tp_ranks_list=None,
    ):
        """Create all EP/TP subgroups and return the pair containing `rank`.

        dist.new_group is collective: every rank must create every group,
        which is why all groups are built even though only one of each kind
        is kept.
        """
        for grp in range(tp_size):
            ranks = (ep_ranks_list[grp] if ep_ranks_list else
                     [r + ep_size * grp for r in range(ep_size)])
            candidate = dist.new_group(backend=backend_type, ranks=ranks)
            if rank in ranks:
                ep_group_tmp = candidate
        for grp in range(ep_size):
            ranks = (tp_ranks_list[grp] if tp_ranks_list else
                     [r * ep_size + grp for r in range(tp_size)])
            candidate = dist.new_group(backend=backend_type, ranks=ranks)
            if rank in ranks:
                tp_group_tmp = candidate
        return ep_group_tmp, tp_group_tmp

    def generate_hcom(self):
        """Initialize the default HCCL group and record its comm name."""
        torch_npu.npu.set_device(self.rank)
        dist.init_process_group(
            backend="hccl",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"tcp://127.0.0.1:{self.port}",
        )

        # ep_size == 0 selects the default (world) group path below.
        ep_size = 0
        tp_size = self.world_size
        comm_info = {
            "default_pg_info": None,
            "ep_hcomm_info": None,
            "group_ep": None,
            "tp_hcomm_info": None,
            "group_tp": None,
        }
        if ep_size and tp_size:
            group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size,
                                                  "hccl", None, None)
            comm_info["ep_hcomm_info"] = self.get_hcomm(group_ep)
            comm_info["tp_hcomm_info"] = self.get_hcomm(group_tp)
            comm_info["group_ep"] = group_ep
            comm_info["group_tp"] = group_tp
        elif dist.is_available():
            comm_info["default_pg_info"] = self.get_hcomm(_get_default_group())
        self.hcomm_info = comm_info["default_pg_info"]

    def run_npu_out(self) -> bool:
        """Build toy inputs and launch the kernel once; True on completion."""
        torch_npu.npu.set_device(self.rank)
        # Deliberately tiny dimensions; comments carry the realistic sizes.
        num_tokens = 2        # token-num 32
        hidden = 4            # hidden_size 7168
        mid_hidden = 4        # mid-hidden-size 4096
        topk = 2
        experts_per_rank = 2  # expert-num-per-rank 16
        k2 = mid_hidden // 2  # down-projection input (activation halves width)
        n2 = hidden

        torch_npu.npu.config.allow_internal_format = True
        x = self.generate_random_tensor((num_tokens, hidden),
                                        dtype=torch.bfloat16).npu()
        weight1 = self.generate_random_tensor(
            (experts_per_rank, hidden, mid_hidden), dtype=torch.int8).npu()
        # format 29 — presumably FRACTAL_NZ; confirm against ACL format enums.
        weight1 = torch_npu.npu_format_cast(weight1, 29)
        weight2 = self.generate_random_tensor((experts_per_rank, k2, n2),
                                              dtype=torch.int8).npu()
        weight2 = torch_npu.npu_format_cast(weight2, 29)

        expert_idx = torch.randint(0,
                                   self.world_size * experts_per_rank,
                                   (num_tokens, topk),
                                   dtype=torch.int32).npu()
        scale1 = torch.randint(0, 1, (experts_per_rank, mid_hidden),
                               dtype=torch.int64).npu()
        scale2 = torch.randint(0, 1, (experts_per_rank, n2),
                               dtype=torch.int64).npu()
        probs = torch.randn(size=(num_tokens, topk), dtype=torch.float32).npu()
        out = self.generate_random_tensor((num_tokens, hidden),
                                          dtype=torch.bfloat16).npu()

        torch.ops._C_ascend.dispatch_ffn_combine(
            x=x,
            weight1=weight1,
            weight2=weight2,
            expert_idx=expert_idx,
            scale1=scale1,
            scale2=scale2,
            probs=probs,
            group=self.hcomm_info,
            max_output_size=512,
            out=out,
        )
        return True

    def generate_random_tensor(self, size, dtype):
        """Return a host tensor of `size`/`dtype` with random contents."""
        if dtype in (torch.float16, torch.bfloat16, torch.float32):
            return torch.randn(size=size, dtype=dtype)
        if dtype == torch.int8:
            return torch.randint(-16, 16, size=size, dtype=dtype)
        if dtype == torch.int32:
            return torch.randint(-1024, 1024, size=size, dtype=dtype)
        raise ValueError(f"Invalid dtype: {dtype}")
|
||||
|
||||
|
||||
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
    """Per-process entry point: set up comms, run the kernel, report result."""
    harness = TestDisptachFFNCombine(rank, world_size, port)
    harness.generate_hcom()
    q.put(harness.run_npu_out())
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_dispatch_ffn_combine_kernel():
    """Spawn one process per rank and assert the kernel ran on every rank."""
    world_size = 2
    mp.set_start_method("fork", force=True)

    result_queue = mp.SimpleQueue()
    # Randomize the rendezvous port to avoid collisions between test runs.
    port = 29501 + random.randint(0, 10000)

    procs = [
        mp.Process(target=worker, args=(rank, world_size, port, result_queue))
        for rank in range(world_size)
    ]
    for proc in procs:
        proc.start()

    # Drain the queue before joining so writers never block on a full pipe.
    results = [result_queue.get() for _ in range(world_size)]

    for proc in procs:
        proc.join()

    assert all(results)
|
||||
@@ -52,6 +52,7 @@ class MoECommType(Enum):
|
||||
ALLGATHER = 0
|
||||
MC2 = 1
|
||||
ALLTOALL = 2
|
||||
FUSED_ALLTOALL = 3
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -520,7 +520,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
return shared_out, fused_output
|
||||
|
||||
@@ -44,6 +44,8 @@ def setup_moe_comm_method(moe_config):
|
||||
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
|
||||
moe_config)
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
@@ -243,3 +245,69 @@ class AlltoAllCommImpl(MoECommMethod):
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
|
||||
class FusedAlltoAllCommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    This implementation uses all-to-all communication to exchange tokens
    between data parallel ranks before and after the MLP computation. It should
    have better performance than AllGatherCommImpl when DP size > 1.
    """

    def _get_token_dispatcher(self):
        # The All2AllV dispatcher is reused here mainly because it exposes
        # moe_all_to_all_group_name, the HCCL group consumed by the fused
        # kernel in fused_experts below.
        return TokenDispatcherWithAll2AllV(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_prepare_finalize(self):
        return PrepareAndFinalizeWithAll2All(self.moe_config)

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            w1_scale_bias: torch.Tensor = None,
            w2_scale_bias: torch.Tensor = None,
            # For TorchAir graph
            is_torchair: bool = False,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: torch.Tensor = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False,
            mc2_mask: torch.Tensor = None,
            pertoken_scale: Optional[torch.Tensor] = None):
        # NOTE(review): most keyword arguments above exist only to match the
        # shared fused_experts interface of the other comm implementations;
        # the fused kernel performs dispatch, quantized FFN and combine
        # internally and does not consume them here.
        out = torch.empty_like(hidden_states)

        # Single fused op: all-to-all token dispatch + W8A8 grouped FFN +
        # combine. w1_scale/w2_scale are expected in the int64 bit-packed
        # form (presumably produced by scale_from_float_to_int64 — confirm).
        torch.ops._C_ascend.dispatch_ffn_combine(
            x=hidden_states,
            weight1=w1,
            weight2=w2,
            expert_idx=topk_ids,
            scale1=w1_scale,
            scale2=w2_scale,
            probs=topk_weights.to(torch.float32),
            group=self.token_dispatcher.moe_all_to_all_group_name,
            # NOTE(review): hard-coded output capacity — confirm it covers
            # the maximum batched token count.
            max_output_size=65536,
            out=out,
        )
        return out
|
||||
|
||||
@@ -513,6 +513,11 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=self.ep_group)
|
||||
backend = self.ep_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
|
||||
@@ -249,8 +249,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
|
||||
} or forward_context.sp_enabled:
|
||||
if moe_comm_type in {
|
||||
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
|
||||
} or forward_context.sp_enabled:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
@@ -232,13 +233,15 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
fused_flag = get_forward_context(
|
||||
).moe_comm_type == MoECommType.FUSED_ALLTOALL
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w1=w1[0] if fused_flag else w1,
|
||||
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
|
||||
w2=w2[0] if fused_flag else w2,
|
||||
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
@@ -270,6 +273,12 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
|
||||
layer.fused_w1_scale = scale_from_float_to_int64(
|
||||
layer.w13_weight_scale.data)
|
||||
layer.fused_w2_scale = scale_from_float_to_int64(
|
||||
layer.w2_weight_scale.data)
|
||||
|
||||
if self.dynamic_eplb:
|
||||
layer.w13_weight_list = [
|
||||
weight.clone()
|
||||
@@ -292,3 +301,11 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def scale_from_float_to_int64(scale):
    """Reinterpret a float scale tensor's float32 bit pattern as int64.

    The fused dispatch_ffn_combine kernel consumes per-channel dequant scales
    as int64 values whose low 32 bits hold the raw IEEE-754 float32 encoding.
    Matching the original numpy `frombuffer` implementation, the result is
    flattened to 1-D and returned on the same device as the input.
    """
    # torch's dtype-view replaces the CPU numpy round-trip of the original
    # (which also produced a read-only buffer). .detach() allows tensors that
    # require grad; .contiguous() guards inputs where .to(float32) is a no-op
    # view, since a dtype-reinterpreting view needs contiguous storage.
    as_bits = (scale.detach().cpu().to(torch.float32).contiguous()
               .view(torch.int32))
    return as_bits.reshape(-1).to(torch.int64).to(scale.device)
|
||||
|
||||
@@ -911,6 +911,9 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
|
||||
"dp": {
|
||||
"hccl_buffer_size": calculate_dp_buffer_size()
|
||||
},
|
||||
"ep": {
|
||||
"hccl_buffer_size": calculate_ep_buffer_size()
|
||||
},
|
||||
}
|
||||
return hccl_config_map.get(group_name, get_default_buffer_config())
|
||||
|
||||
@@ -932,6 +935,30 @@ def calculate_dp_buffer_size() -> int:
|
||||
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
|
||||
|
||||
|
||||
def calculate_ep_buffer_size() -> int:
    """Estimate the HCCL buffer size (MiB) for the EP group.

    Sized as batch_size * hidden_size * topk * (2 int8 payloads + 1 bf16
    payload), rounded up to a whole MiB; falls back to the default size when
    no vLLM config is available or any lookup fails.
    """
    buffer_mb = _DEFAULT_BUFFER_SIZE
    try:
        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()
        hf_config = vllm_config.model_config.hf_config

        hidden_size = hf_config.hidden_size
        topk = getattr(hf_config, "num_experts_per_token", 1)
        tokens = vllm_config.scheduler_config.max_num_batched_tokens
        bytes_int8 = torch.iinfo(torch.int8).bits // 8
        bytes_bf16 = torch.finfo(torch.bfloat16).bits // 8
        payload = tokens * hidden_size * topk * (bytes_int8 * 2 + bytes_bf16)
        buffer_mb = math.ceil(payload / (1024 * 1024))
    except Exception:
        # Best-effort sizing: keep the default on any failure.
        pass
    return max(buffer_mb, _DEFAULT_BUFFER_SIZE)
|
||||
|
||||
|
||||
# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
|
||||
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
||||
# significantly improve communication performance of MC2 ops dispatch/combine.
|
||||
|
||||
@@ -2217,8 +2217,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
return None
|
||||
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
||||
'moe_quantize', None)
|
||||
quant_type = getattr(
|
||||
self.vllm_config.model_config.hf_config, 'moe_quantize',
|
||||
getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
|
||||
model_type = self.vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if not self.parallel_config.enable_expert_parallel:
|
||||
@@ -2237,7 +2238,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
elif soc_version in {AscendDeviceType._910_93}:
|
||||
moe_comm_type = (MoECommType.MC2
|
||||
if num_tokens <= self.mc2_tokens_capacity else
|
||||
MoECommType.ALLTOALL)
|
||||
MoECommType.FUSED_ALLTOALL if quant_type
|
||||
== "w8a8_dynamic" else MoECommType.ALLTOALL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user