[Kernel] add custom op DispatchGmmCombineDecode (#4139)
#### What this PR does / why we need it? Add the custom op API DispatchGmmCombineDecode for A3, including the kernel implementation, the Python API, and pytest coverage. vLLM version: v0.11.0 vLLM main: 24d6314718 - vLLM version: v0.12.0 - vLLM main: ad32e3e19c Signed-off-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangqiankun <wangqiankun13@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -138,3 +138,21 @@ jobs:
|
||||
image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
|
||||
tests: ${{ matrix.test_config.tests }}
|
||||
name: ${{ matrix.test_config.name }}
|
||||
custom-ops-tests:
|
||||
name: test ops
|
||||
if: always() && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch')
|
||||
needs: multi-node-tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test_config:
|
||||
- name: custom-op-dispatch_gmm_combine_decode
|
||||
os: linux-aarch64-a3-16
|
||||
tests: tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
|
||||
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
|
||||
with:
|
||||
runner: ${{ matrix.test_config.os }}
|
||||
vllm: v0.12.0
|
||||
image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
|
||||
tests: ${{ matrix.test_config.tests }}
|
||||
name: ${{ matrix.test_config.name }}
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
ROOT_DIR=$1
|
||||
SOC_VERSION=$2
|
||||
|
||||
git config --global --add safe.directory "$ROOT_DIR"
|
||||
|
||||
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
# ASCEND310P series
|
||||
# currently, no custom aclnn ops for ASCEND310 series
|
||||
@@ -13,11 +11,41 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
|
||||
exit 0
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
# ASCEND910B (A2) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
|
||||
# dependency: catlass
|
||||
git config --global --add safe.directory "$ROOT_DIR"
|
||||
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass/include
|
||||
if [[ ! -d "${CATLASS_PATH}" ]]; then
|
||||
echo "depdendency catlass is missing, try to fetch it..."
|
||||
if ! git submodule update --init --recursive; then
|
||||
echo "fetch failed"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
# dependency: cann-toolkit file moe_distribute_base.h
|
||||
HCCL_STRUCT_FILE_PATH=$(find -L "${ASCEND_TOOLKIT_HOME}" -name "moe_distribute_base.h" 2>/dev/null | head -n1)
|
||||
if [ -z "$HCCL_STRUCT_FILE_PATH" ]; then
|
||||
echo "cannot find moe_distribute_base.h file in CANN env"
|
||||
exit 1
|
||||
fi
|
||||
# for dispatch_gmm_combine_decode
|
||||
yes | cp "${HCCL_STRUCT_FILE_PATH}" "${ROOT_DIR}/csrc/dispatch_gmm_combine_decode/op_kernel"
|
||||
# for dispatch_ffn_combine
|
||||
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
|
||||
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
|
||||
TARGET_FILE="$TARGET_DIR/$(basename "$HCCL_STRUCT_FILE_PATH")"
|
||||
|
||||
echo "*************************************"
|
||||
echo $HCCL_STRUCT_FILE_PATH
|
||||
echo "$TARGET_DIR"
|
||||
cp "$HCCL_STRUCT_FILE_PATH" "$TARGET_DIR"
|
||||
|
||||
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
|
||||
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine;dispatch_gmm_combine_decode;"
|
||||
SOC_ARG="ascend910_93"
|
||||
else
|
||||
# others
|
||||
@@ -25,29 +53,6 @@ else
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git submodule init
|
||||
git submodule update
|
||||
|
||||
|
||||
# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h
|
||||
file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1)
|
||||
if [ -z "$file_path" ]; then
|
||||
echo "cannot find moe_distribute_base.h file in CANN env"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
|
||||
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
|
||||
TARGET_FILE="$TARGET_DIR/$(basename "$file_path")"
|
||||
|
||||
echo "*************************************"
|
||||
echo $file_path
|
||||
echo "$TARGET_DIR"
|
||||
cp "$file_path" "$TARGET_DIR"
|
||||
|
||||
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
|
||||
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
|
||||
|
||||
|
||||
# build custom ops
|
||||
cd csrc
|
||||
|
||||
59
csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt
Normal file
59
csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
set(_DISPATCH_GMM_INC_OPTS)
|
||||
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
|
||||
list(APPEND _DISPATCH_GMM_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
|
||||
else()
|
||||
message(FATAL_ERROR "dependency catlass is missing, you can fetch it by running 'git submodule update --init --recursive'")
|
||||
endif()
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME DispatchGmmCombineDecode
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
${_DISPATCH_GMM_INC_OPTS}
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnnInner PRIVATE
|
||||
dispatch_gmm_combine_decode_def.cpp
|
||||
)
|
||||
|
||||
target_sources(opapi PRIVATE
|
||||
aclnn_dispatch_gmm_combine_decode.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(aclnn_ops_train PRIVATE
|
||||
aclnn_dispatch_gmm_combine_decode.cpp
|
||||
)
|
||||
|
||||
target_sources(aclnn_ops_infer PRIVATE
|
||||
aclnn_dispatch_gmm_combine_decode.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
dispatch_gmm_combine_decode_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
dispatch_gmm_combine_decode_proto.cpp
|
||||
)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include <string.h>
|
||||
#include "graph/types.h"
|
||||
#include "aclnn/opdev/platform.h"
|
||||
#include "aclnn_dispatch_gmm_combine_decode.h"
|
||||
|
||||
enum NnopbaseHcclServerType {
|
||||
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_MTE,
|
||||
NNOPBASE_HCCL_SERVER_TYPE_END
|
||||
};
|
||||
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
|
||||
const aclTensor *x,
|
||||
const aclTensor *expertIds,
|
||||
const aclTensor *gmm1PermutedWeight,
|
||||
const aclTensor *gmm1PermutedWeightScale,
|
||||
const aclTensor *gmm2Weight,
|
||||
const aclTensor *gmm2WeightScale,
|
||||
const aclTensor *expertSmoothScalesOptional,
|
||||
const aclTensor *expertScalesOptional,
|
||||
char *groupEp,
|
||||
int64_t epRankSize,
|
||||
int64_t epRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t shareExpertNum,
|
||||
int64_t shareExpertRankNum,
|
||||
int64_t quantMode,
|
||||
int64_t globalBs,
|
||||
const aclTensor *output,
|
||||
const aclTensor *epRecvCount,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
/**
 * Public workspace-size query for DispatchGmmCombineDecode.
 *
 * This is a thin forwarding shim: every argument is handed unchanged to the
 * generated inner implementation, which computes the required workspace size
 * and builds the executor. No validation happens at this layer.
 */
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
    const aclTensor *x, const aclTensor *expertIds,
    const aclTensor *gmm1PermutedWeight, const aclTensor *gmm1PermutedWeightScale,
    const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
    const aclTensor *expertSmoothScalesOptional, const aclTensor *expertScalesOptional,
    char *groupEp, int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum,
    int64_t shareExpertNum, int64_t shareExpertRankNum, int64_t quantMode, int64_t globalBs,
    const aclTensor *output, const aclTensor *epRecvCount,
    uint64_t *workspaceSize, aclOpExecutor **executor)
{
    // Pure pass-through; argument order mirrors the inner declaration exactly.
    return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
        x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
        gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional,
        groupEp, epRankSize, epRankId, moeExpertNum,
        shareExpertNum, shareExpertRankNum, quantMode, globalBs,
        output, epRecvCount, workspaceSize, executor);
}
|
||||
|
||||
/**
 * Public launch entry for DispatchGmmCombineDecode.
 *
 * If the (weakly linked) NnopbaseSetHcclServerType hook is available, select
 * the HCCL server type by SoC generation before launching: ASCEND910B uses
 * the AICPU server, every other SoC uses the MTE server. Then forward to the
 * generated inner launch routine.
 */
aclnnStatus aclnnDispatchGmmCombineDecode(
    void *workspace,
    uint64_t workspaceSize,
    aclOpExecutor *executor,
    aclrtStream stream)
{
    // Weak symbol: may be absent in older runtimes, so test before calling.
    if (NnopbaseSetHcclServerType != nullptr) {
        const bool isAscend910B =
            op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B;
        NnopbaseSetHcclServerType(
            executor,
            isAscend910B ? NNOPBASE_HCCL_SERVER_TYPE_AICPU : NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef DISPATCH_GMM_COMBINE_DECODE
|
||||
#define DISPATCH_GMM_COMBINE_DECODE
|
||||
|
||||
#include "aclnn/acl_meta.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
|
||||
const aclTensor *x,
|
||||
const aclTensor *expertIds,
|
||||
const aclTensor *gmm1PermutedWeight,
|
||||
const aclTensor *gmm1PermutedWeightScale,
|
||||
const aclTensor *gmm2Weight,
|
||||
const aclTensor *gmm2WeightScale,
|
||||
const aclTensor *expertSmoothScalesOptional,
|
||||
const aclTensor *expertScalesOptional,
|
||||
char *groupEp,
|
||||
int64_t epRankSize,
|
||||
int64_t epRankId,
|
||||
int64_t moeExpertNum,
|
||||
int64_t shareExpertNum,
|
||||
int64_t shareExpertRankNum,
|
||||
int64_t quantMode,
|
||||
int64_t globalBs,
|
||||
const aclTensor *output,
|
||||
const aclTensor *epRecvCount,
|
||||
uint64_t *workspaceSize,
|
||||
aclOpExecutor **executor);
|
||||
|
||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
|
||||
void *workspace,
|
||||
uint64_t workspaceSize,
|
||||
aclOpExecutor *executor,
|
||||
aclrtStream stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
// IR registration for the DispatchGmmCombineDecode fused op.
// Each input/output declares two dtype/format combinations, which correspond
// to the two supported token dtypes of "x" (BF16 and FP16).
// NOTE(review): the declaration order of Input()/Attr() calls is load-bearing —
// it must match the *_INDEX constants used by the proto (InferShape) and
// tiling code; do not reorder.
class DispatchGmmCombineDecode : public OpDef
{
public:
    explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
    {
        // Input 0: token data, BF16 or FP16, ND layout.
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Input 1: per-token expert selection, int32.
        this->Input("expert_ids")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Input 2: quantized (int8) GMM1 weights, pre-permuted, NZ layout.
        this->Input("gmm1_permuted_weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // Input 3: float dequant scales for GMM1.
        this->Input("gmm1_permuted_weight_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Input 4: quantized (int8) GMM2 weights, NZ layout.
        this->Input("gmm2_weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
        // Input 5: float dequant scales for GMM2.
        this->Input("gmm2_weight_scale")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Input 6 (optional): per-expert smooth-quant scales.
        this->Input("expert_smooth_scales")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Input 7 (optional): per-token expert weighting used by combine.
        this->Input("expert_scales")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Output 0: combined tokens; dtype follows "x" (see proto InferDataType).
        this->Output("output")
            .ParamType(REQUIRED)
            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Output 1: per-rank/per-expert receive counts, int32.
        this->Output("ep_recv_count")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        // Attributes, indices 0..7 — must stay in sync with ATTR_*_INDEX
        // constants in the proto and tiling sources.
        this->Attr("group_ep").String();
        this->Attr("ep_rank_size").Int();
        this->Attr("ep_rank_id").Int();
        this->Attr("moe_expert_num").Int();
        this->Attr("share_expert_num").Int();
        this->Attr("share_expert_rank_num").Int();
        this->Attr("quant_mode").Int();
        this->Attr("global_bs").Int();

        // MC2 (fused comm+compute) op: the HCCL group comes from attr "group_ep".
        this->MC2().HcclGroup({"group_ep"});
        // Only registered for the A3 (ascend910_93) AICore target.
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(DispatchGmmCombineDecode);
} // namespace ops
|
||||
@@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ge {
|
||||
constexpr uint32_t EXPAND_X_INDEX = 0;
|
||||
constexpr uint32_t EXPERT_IDS_INDEX = 1;
|
||||
constexpr uint32_t OUTPUT_X_INDEX = 0;
|
||||
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
|
||||
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
|
||||
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
|
||||
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
|
||||
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
|
||||
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
|
||||
|
||||
/**
 * Infer output shapes for DispatchGmmCombineDecode.
 *
 * output:        rank follows x; dim0 = bs (expert_ids dim0), dim1 = h (x dim1).
 * ep_recv_count: 1-D. For a shared-expert rank (epRankId < sharedExpertRankNum)
 *                its length is epRankSize; otherwise it is
 *                epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)).
 *
 * Returns GRAPH_FAILED on null shapes/attrs, insufficient input ranks, or an
 * attr combination that would make the recv-count computation invalid.
 */
static ge::graphStatus InferShape(gert::InferShapeContext *context)
{
    const char *nodeName = context->GetNodeName();
    // infer output shape
    const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
    const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
    gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
    gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
    if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
        recvCountOutShape == nullptr) {
        return GRAPH_FAILED;
    }
    if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
        return GRAPH_FAILED;
    }

    // Fix: keep GetDim()'s int64_t instead of narrowing to int.
    int64_t bs = expertIdsShape->GetDim(0);
    int64_t h = expandXShape->GetDim(1);

    expandXOutShape->SetDimNum(expandXShape->GetDimNum());
    expandXOutShape->SetDim(0, bs);
    expandXOutShape->SetDim(1, h);

    // infer recvCount shape
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);

    OPS_ERR_IF(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(epRankSizePtr == nullptr, OPS_LOG_E(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertRankNumPtr is nullptr."),
               return ge::GRAPH_FAILED);
    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);

    // Fix: guard the divisor below. epRankSize == sharedExpertRankNum would
    // divide by zero, and sharedExpertRankNum > epRankSize would wrap the
    // unsigned subtraction into a huge value.
    OPS_ERR_IF(epRankSize <= sharedExpertRankNum,
               OPS_LOG_E(nodeName, "epRankSize must be greater than shareExpertRankNum."),
               return ge::GRAPH_FAILED);

    recvCountOutShape->SetDimNum(1);
    bool isShareExpert = (epRankId < sharedExpertRankNum);
    if (isShareExpert) {
        // Shared-expert ranks receive from every EP rank once.
        recvCountOutShape->SetDim(0, epRankSize);
    } else {
        // MoE ranks: one counter slot per (rank, local expert) pair.
        recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
    }

    return GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * Propagate output dtypes: "output" mirrors the dtype of input x;
 * "ep_recv_count" is always int32.
 */
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    context->SetOutputDataType(OUTPUT_X_INDEX, context->GetInputDataType(EXPAND_X_INDEX));
    context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType);
|
||||
} // namespace ge
|
||||
@@ -0,0 +1,335 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "log/ops_log.h"
|
||||
#include "error/ops_error.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/hccl/hccl_tiling.h"
|
||||
|
||||
using namespace ge;
|
||||
namespace {
|
||||
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint32_t GM_ALIGN_SIZE = 512;
|
||||
constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2;
|
||||
constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024;
|
||||
constexpr uint32_t CUBE_WORKSPACE_STAGE = 4;
|
||||
constexpr uint32_t RESERVED_WORKSPACE_SIZE = 256 * 1024;
|
||||
|
||||
constexpr uint32_t INPUT_X_INDEX = 0;
|
||||
constexpr uint32_t INPUT_EXPERT_IDS_INDEX = 1;
|
||||
constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
|
||||
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
|
||||
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
|
||||
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
|
||||
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
|
||||
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
|
||||
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
|
||||
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
|
||||
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
|
||||
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
|
||||
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
|
||||
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
|
||||
|
||||
constexpr uint32_t MIN_BATCH_SIZE = 1;
|
||||
constexpr uint32_t MAX_BATCH_SIZE = 256;
|
||||
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
|
||||
constexpr uint32_t SUPPORT_TOP_K = 12;
|
||||
constexpr uint32_t TWO_DIMS = 2;
|
||||
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
|
||||
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
|
||||
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
|
||||
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static size_t CeilUp(size_t x, size_t y)
|
||||
{
|
||||
return (x + y - 1) / y * y;
|
||||
}
|
||||
|
||||
/**
 * Validate the GMM weight / weight-scale tensor shapes against the tiling
 * data that was already derived from the op attributes.
 *
 * Checks, per tensor:
 *   gmm1_permuted_weight:        dim0 == localExpertNum
 *   gmm1_permuted_weight_scale:  rank 2, dim0 == localExpertNum, dim1 == gmm1HLen
 *   gmm2_weight:                 dim0 == localExpertNum
 *   gmm2_weight_scale:           rank 2, dim0 == localExpertNum, dim1 == h
 *
 * Returns GRAPH_FAILED (with a logged message) on the first violation.
 */
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
                                        DispatchGmmCombineDecodeTilingData &tilingData)
{
    uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
    // NOTE(review): moeExpertNum is read here but never used in this function.
    uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
    uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
    uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
    // Despite the Dim2 name, this is compared against the *scale* tensor's dim1
    // and holds the GMM1 hidden length from tiling data.
    uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;

    // A shared-expert rank hosts exactly one expert; a MoE rank hosts
    // moeExpertNumPerRank experts.
    uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
    const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
    OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
               OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
               return ge::GRAPH_FAILED);

    const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
    OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);

    const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
    OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
                         gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
    OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
               OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
               return ge::GRAPH_FAILED);
    const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
    OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
               return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
/**
 * Validate (and normalize) scalar tiling parameters.
 *
 * - bs must lie in [MIN_BATCH_SIZE, MAX_BATCH_SIZE]
 * - h must lie in [MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH]
 * - gmm1HLen must lie in [MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN]
 * - k must be <= SUPPORT_TOP_K
 * - globalBs == 0 means "default": it is rewritten in tilingData to
 *   epRankSize * bs; otherwise it must be divisible by epRankSize
 * - moeExpertNumPerRank must not exceed half the AIV core count
 *
 * Returns GRAPH_FAILED (with a logged message) on the first violation.
 */
static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
{
    uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
    OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE),
               return ge::GRAPH_FAILED);
    uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
    OPS_ERR_IF(
        tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
        OPS_LOG_E(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
        return ge::GRAPH_FAILED);
    uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
    OPS_ERR_IF(
        gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
        OPS_LOG_E(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
        return ge::GRAPH_FAILED);
    uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
    OPS_ERR_IF(topK > SUPPORT_TOP_K, OPS_LOG_E(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K),
               return ge::GRAPH_FAILED);
    uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
    uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
    // Fix: guard the division/modulo by epRankSize below.
    OPS_ERR_IF(epRankSize == 0, OPS_LOG_E(nodeName, "epRankSize must > 0."), return ge::GRAPH_FAILED);
    if (globalBatchSize == 0) {
        // 0 means "use the default": every EP rank contributes batchSize tokens.
        globalBatchSize = epRankSize * batchSize;
        tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize;
    } else {
        // Fix: the original `globalBatchSize < 0` check was a tautology on an
        // unsigned value (always false, flagged by compilers). A negative
        // int64 attr has already been wrapped by the uint32_t cast upstream,
        // so only divisibility is meaningfully checkable here.
        OPS_ERR_IF(globalBatchSize % epRankSize > 0,
                   OPS_LOG_E(nodeName, "globalBatchSize must be divisible by epRankSize."),
                   return ge::GRAPH_FAILED);
    }
    uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    // Half of the AIV cores handle receive; each local expert needs one.
    uint32_t recvAivNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aivNum / 2;
    OPS_ERR_IF(
        moeExpertNumPerRank > recvAivNum,
        OPS_LOG_E(nodeName, "moeExpertNumPerRank must <= (aivNum/2)(%u), but got %u", recvAivNum, moeExpertNumPerRank),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
                                               DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
{
    // Reads the op attributes (EP communication group, rank topology, expert layout,
    // quant mode, global batch size), validates them, and records them in the common
    // tiling-data section. Returns GRAPH_FAILED on any invalid attribute.
    auto attrs = context->GetAttrs();
    OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);

    auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
    auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
    auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
    auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
    auto sharedExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_NUM_INDEX);
    auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
    auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
    auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
    // Fix: every attribute pointer is dereferenced below, so each one must be null-checked first.
    OPS_ERR_IF(groupEpPtr == nullptr || epRankSizePtr == nullptr || epRankIdPtr == nullptr ||
                   moeExpertNumPtr == nullptr || sharedExpertNumPtr == nullptr ||
                   sharedExpertRankNumPtr == nullptr || quantModePtr == nullptr || globalBsPtr == nullptr,
               OPS_LOG_E(nodeName, "attr pointer is nullptr."), return ge::GRAPH_FAILED);

    uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
    uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
    uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
    uint32_t sharedExpertNum = static_cast<uint32_t>(*sharedExpertNumPtr);
    uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
    // Fix: guard the divisor BEFORE dividing. The original divided first, so
    // epRankSize == sharedExpertRankNum caused a division by zero (and, the operands
    // being unsigned, epRankSize < sharedExpertRankNum wrapped to a huge divisor).
    OPS_ERR_IF(epRankSize <= sharedExpertRankNum,
               OPS_LOG_E(nodeName, "epRankSize must be > sharedExpertRankNum."), return ge::GRAPH_FAILED);
    uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum);

    // Note: the original also checked `epRankId < 0`, which is always false for an
    // unsigned value; the check is dropped.
    OPS_ERR_IF(epRankId >= epRankSize, OPS_LOG_E(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum > MAX_MOE_EXERT_NUM, OPS_LOG_E(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXERT_NUM),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum <= 0, OPS_LOG_E(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(sharedExpertNum != 1, OPS_LOG_E(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0,
               OPS_LOG_E(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."),
               return ge::GRAPH_FAILED);

    // Publish the validated values to the tiling data consumed by the kernel.
    groupEp = std::string(groupEpPtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = epRankSize;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = epRankId;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = moeExpertNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = sharedExpertNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum = sharedExpertRankNum;
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
    tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank = moeExpertNumPerRank;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling, const std::string groupEp)
{
    // Fills the MC2 (fused communication + compute) tiling sections with the HCCL
    // AllToAll algorithm configuration for the EP communication group.
    const char *nodeName = context->GetNodeName();
    OPS_LOG_D(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str());
    uint32_t opType = OP_TYPE_ALL_TO_ALL;
    // Fix: removed the unused local `algConfigAllGatherStr` ("AllGather=level0:ring") —
    // only the AlltoAll configuration is ever handed to Mc2CcTilingConfig.
    std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";

    AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr);
    // Both the init tiling and the per-call cc tiling are generated from the same config.
    mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
    mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
}
|
||||
|
||||
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName,
                                    DispatchGmmCombineDecodeTilingData &tilingData)
{
    // Computes the total device workspace the fused op needs (system reserved area plus
    // all intermediate buffers) and publishes it through the tiling context.
    size_t *workSpaces = context->GetWorkspaceSizes(1);
    OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);

    const auto &info = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo;
    uint32_t rankCount = info.epRankSize;
    uint32_t rankId = info.epRankId;
    uint32_t sharedRankCount = info.sharedExpertRankNum;
    uint32_t localBatch = info.bs;
    uint32_t perRankBatch = info.globalBs / rankCount;
    uint32_t routedTopK = info.k;
    uint32_t expertsPerRank = info.moeExpertNumPerRank;
    uint32_t hidden = info.h;
    uint32_t cubeCores = info.aicNum;
    uint64_t gmm2Hidden = info.gmm1HLen / 2;  // swiglu halves the gmm1 output width

    // Worst-case number of tokens this rank may hold: shared-expert ranks split the
    // global token stream evenly; MoE ranks can receive up to min(topK, expertsPerRank)
    // copies of every token.
    size_t tokenCapacity = (rankId < sharedRankCount)
                               ? (perRankBatch * rankCount / sharedRankCount)
                               : (perRankBatch * rankCount * std::min(routedTopK, expertsPerRank));

    // Per-buffer sizes, each rounded up to the GM alignment granularity.
    size_t x2TokenBytes = CeilUp(tokenCapacity * gmm2Hidden * sizeof(int8_t), GM_ALIGN_SIZE);
    size_t x2ScaleBytes = CeilUp(tokenCapacity * sizeof(float), GM_ALIGN_SIZE);
    size_t cvSwapBytes = CeilUp(cubeCores * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
    size_t swigluOutBytes = CeilUp(tokenCapacity * gmm2Hidden * sizeof(float), GM_ALIGN_SIZE);
    size_t groupListBytes = CeilUp(expertsPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
    size_t expandIdxBytes = CeilUp(localBatch * routedTopK * sizeof(int32_t), GM_ALIGN_SIZE);
    size_t epSendCountBytes = CeilUp(rankCount * expertsPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
    size_t x1TokenBytes = CeilUp(tokenCapacity * hidden * sizeof(int8_t), GM_ALIGN_SIZE);
    size_t x1ScaleBytes = CeilUp(tokenCapacity * sizeof(float), GM_ALIGN_SIZE);
    size_t gmm2DeqOutBytes = CeilUp(tokenCapacity * hidden * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
    size_t reservedBytes = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);

    size_t userBytes = x2TokenBytes + x2ScaleBytes + cvSwapBytes + swigluOutBytes + groupListBytes +
                       expandIdxBytes + epSendCountBytes + x1TokenBytes + x1ScaleBytes + gmm2DeqOutBytes +
                       reservedBytes;

    workSpaces[0] = SYSTEM_NEED_WORKSPACE + userBytes;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Main tiling routine: derives shapes from the op inputs, reads and validates the
// attributes, sizes the workspace, fills the HCCL/MC2 tiling, and selects the tiling key.
static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context)
{
    const char *nodeName = context->GetNodeName();
    DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData<DispatchGmmCombineDecodeTilingData>();
    OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
    std::string groupEp = "";

    // x: [bs, h] token matrix -> local batch size and hidden size.
    const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX);
    OPS_ERR_IF(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
    OPS_ERR_IF(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "x shape dims must be 2, but current dim num is %lu.",
                         xStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t batchSize = xStorageShape->GetStorageShape().GetDim(0);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs = batchSize;
    const int64_t hiddenSize = xStorageShape->GetStorageShape().GetDim(1);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize;

    // expert_ids: [bs, k] -> top-k expert count per token.
    const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX);
    OPS_ERR_IF(expertIdsStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIds shape is null."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
               OPS_LOG_E(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.",
                         expertIdsStorageShape->GetStorageShape().GetDimNum()),
               return ge::GRAPH_FAILED);
    const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
    // Attributes must be parsed after bs/h/k are in place (CheckData reads all of them).
    OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
               OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
    // gmm1 weight: dim index TWO_DIMS (2) of the origin shape is the gmm1 output width.
    // NOTE(review): assumes the origin shape has at least 3 dims — confirm against the op proto.
    const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
    OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
               return ge::GRAPH_FAILED);
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
    // Record the physical core counts; CheckData validates moeExpertNumPerRank against aivNum/2.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
    tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
    OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
               return ge::GRAPH_FAILED);
    OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
               OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
    SetHcommCfg(context, tilingData, groupEp);
    // Tiling key 0: exactly one MoE expert per rank (shallow path);
    // otherwise EXEC_FLAG_DEEP_FUSE selects the deep-fused kernel path.
    if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
        context->SetTilingKey(0);
    } else {
        context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
    }
    // The kernel is launched with one block per cube core (1C2V mix).
    context->SetBlockDim(aicNum);
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Registered tiling entry point; all real work happens in the Impl function.
static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context)
{
    return DispatchGmmCombineDecodeTilingFuncImpl(context);
}
|
||||
|
||||
// No compile-time information is needed by this op; an empty struct satisfies the
// optiling registration API.
struct DispatchGmmCombineDecodeCompileInfo {};
// Tiling-parse hook: nothing to parse, so it is a trivially successful no-op.
ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context)
{
    (void)context;
    return ge::GRAPH_SUCCESS;
}
|
||||
|
||||
// Register the tiling and tiling-parse entry points for DispatchGmmCombineDecode
// with the optiling framework.
IMPL_OP_OPTILING(DispatchGmmCombineDecode)
    .Tiling(DispatchGmmCombineDecodeTilingFunc)
    .TilingParse<DispatchGmmCombineDecodeCompileInfo>(TilingParseForDispatchGmmCombineDecode);
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "dispatch_gmm_combine_decode.h"
|
||||
#include <kernel_operator.h>
|
||||
#include "lib/matmul_intf.h"
|
||||
|
||||
// Device entry point for the fused dispatch + grouped-matmul + combine decode op.
// Runs as a mixed AIC/AIV kernel (1 cube : 2 vector cores); the heavy lifting lives
// in DispatchGmmCombineDecode::Init/Process.
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
    // input
    GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
    // output
    GM_ADDR output, GM_ADDR outputRecvCount,
    // system
    GM_ADDR workspace, GM_ADDR tiling)
{
    // Warm the instruction cache before entering the long kernel body.
    icache_preload(8);
    REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
    GET_TILING_DATA(tiling_data, tiling);
    // Host tiling sets key 0 (shallow path, one expert per rank) or EXEC_FLAG_DEEP_FUSE.
    // NOTE(review): this branch accepts keys 0 and 1, which presumes
    // EXEC_FLAG_DEEP_FUSE == 1 — confirm against the flag definition.
    if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
        DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
        // The TPipe argument is nullptr: pipes are created internally where needed.
        op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
                expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
        op.Process();
    }
}
|
||||
@@ -0,0 +1,433 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef DISPATCH_GMM_COMBINE_DECODE_H
|
||||
#define DISPATCH_GMM_COMBINE_DECODE_H
|
||||
|
||||
#include "lib/matmul_intf.h"
|
||||
#include <kernel_operator.h>
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/arch.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/epilogue/tile/tile_broadcast_mul.hpp"
|
||||
#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp"
|
||||
#include "catlass/epilogue/tile/tile_swizzle.hpp"
|
||||
#include "catlass/gemm/block/block_swizzle.hpp"
|
||||
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h"
|
||||
#include "catlass/gemm/gemm_type.hpp"
|
||||
#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h"
|
||||
#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h"
|
||||
#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h"
|
||||
#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h"
|
||||
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h"
|
||||
|
||||
#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h"
|
||||
|
||||
#include "dispatch_gmm_combine_decode_tiling.h"
|
||||
#include "dispatch_gmm_combine_decode_base.h"
|
||||
|
||||
using namespace Catlass;
|
||||
|
||||
// Default MMAD dispatch policy: preloaded, asynchronous mmad with an epilogue callback;
// stage depths and feature switches come from the CUSTOM_* build-time constants.
using MmadAtlasA2Custom =
    Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
                                              CUSTOM_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG,
                                              CUSTOM_ENABLE_SHUFFLE_K>;

// GMM1 (up-projection, followed by swiglu + quant epilogue) tile configuration.
using Gmm1L1TileShape = GemmShape<GMM1_L1M, GMM1_L1N, GMM1_L1K>;
using Gmm1L0TileShape = GemmShape<GMM1_L1M, GMM1_L1N, GMM1_L0K>;
using Gmm1EpilogueTileShape = MatrixShape<GMM1_EPIM, Gmm1L1TileShape::N>;
using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM1_SWIZZLE_OFFSET, GMM1_SWIZZLE_DIRECTION>;

// GMM2 (down-projection, feeding the combine) tile configuration; its dispatch policy
// keeps the A operand resident across stages.
using Gmm2L1TileShape = GemmShape<GMM2_L1M, GMM2_L1N, GMM2_L1K>;
using Gmm2L0TileShape = GemmShape<Gmm2L1TileShape::M, Gmm2L1TileShape::N, GMM2_L0K>;
using Gmm2EpilogueTileShape = MatrixShape<GMM2_EPIM, Gmm2L1TileShape::N>;
using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM2_SWIZZLE_OFFSET, GMM2_SWIZZLE_DIRECTION>;
using Gmm2DispatchPolicy =
    Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA<CUSTOM_PRELOAD_STAGES, GMM2_L1A_STAGES, GMM2_L1B_STAGES,
                                                       GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
                                                       CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
|
||||
|
||||
// Runs GMM1: grouped matmul (int8 x int8 -> int32) with a per-token dequant + swiglu +
// re-quant epilogue. Depending on EXEC_FLAG, either the deep-fused kernel variant
// (dispatch fused into the matmul) or the shallow variant (dispatch already done by the
// caller) is instantiated; the two variants take different parameter sets.
template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
          class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
                                      layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
                                      layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
                                      layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
                                      GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
                                      GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
                                      GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
                                      uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
                                      uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
                                      uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
{
    using ArchTag = Arch::AtlasA2;
    using DispatchPolicy = DispatchPolicy_;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;

    // int8 A (row-major) x int8 B (zN) accumulating into int32 C.
    using XType = XType_;
    using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
    using BType = Gemm::GemmType<int8_t, layout::zN>;
    using CType = Gemm::GemmType<int32_t, layout::RowMajor>;

    using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;

    // Epilogue: per-token dequant followed by swiglu, single UB stage.
    constexpr uint32_t ubStages = 1;
    using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu<ubStages, 0>;
    using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
    using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
    using DType = Gemm::GemmType<float, layout::RowMajor>;

    // Broadcast helpers used by the dequant epilogue (row-wise and per-block scaling).
    using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
    using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
    using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;

    using EpilogueTileShape = EpilogueTileShape_;
    using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
    using TileBroadcastOneBlk =
        Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
    using TileOneBlkColumnBroadcastMul =
        Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
    using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
    using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;

    using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
                                                         DType, TileRowBroadcastMul, TileBroadcastOneBlk,
                                                         TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;

    using BlockScheduler = BlockScheduler_;

    // Kernel level: deep-fused variant embeds the token dispatch; the shallow variant
    // expects tokens already dispatched by the caller.
    using ElementGroupList = int64_t;

    using GemmKernel = typename std::conditional<
        (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
        Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
            XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
        Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
            BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;

    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        // Deep-fused path: the kernel also needs the raw tokens, expert routing tables,
        // communication topology and batch geometry. Field order must match Params.
        typename GemmKernel::Params params{problemShape,
                                           groupCount,
                                           gmGroupList,
                                           gmA,
                                           layoutA,
                                           gmB,
                                           layoutB,
                                           gmScale,
                                           layoutScale,
                                           gmPerTokenScale,
                                           layoutPerTokenScale,
                                           gmD,
                                           layoutD,
                                           gmDequantScale,
                                           layoutDequantScale,
                                           gmWorkspace,
                                           gmX,
                                           debugGm,
                                           gmexpertIds,
                                           gmExpandIdx,
                                           gmEpSendCount,
                                           gmResvered,
                                           gmOutputRecvCount,
                                           epRankSize,
                                           epRankId,
                                           moeExpertNum,
                                           moeExpertNumPerRank,
                                           sharedExpertNum,
                                           sharedExpertRankNum,
                                           quantMode,
                                           globalBs,
                                           bs,
                                           topK,
                                           tokenLen};
        // call a kernel
        GemmKernel gemm;
        gemm(params);
    } else {
        // Shallow path: plain grouped matmul + epilogue over pre-dispatched tokens.
        typename GemmKernel::Params params{problemShape,
                                           groupCount,
                                           gmGroupList,
                                           gmA,
                                           layoutA,
                                           gmB,
                                           layoutB,
                                           gmScale,
                                           layoutScale,
                                           gmPerTokenScale,
                                           layoutPerTokenScale,
                                           gmD,
                                           layoutD,
                                           gmDequantScale,
                                           layoutDequantScale,
                                           gmWorkspace};
        // call a kernel
        GemmKernel gemm;
        gemm(params);
    }
}
|
||||
|
||||
// Runs GMM2: grouped matmul (int8 x int8 -> int32) with a per-token dequant epilogue
// whose output feeds the combine stage via the `combiner` callback object.
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
          class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
                           layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
                           layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
                           layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
                           GM_ADDR gmWorkspace, void *combiner)
{
    using ArchTag = Arch::AtlasA2;
    using DispatchPolicy = DispatchPolicy_;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;

    // int8 A (row-major) x int8 B (nZ — note: GMM1 uses zN) into int32 C.
    using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
    using BType = Gemm::GemmType<int8_t, layout::nZ>;
    using CType = Gemm::GemmType<int32_t, layout::RowMajor>;

    using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;

    // Epilogue: per-token dequant fused with the combine, single UB stage.
    constexpr uint32_t ubStages = 1;
    using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine<ubStages, EXEC_FLAG>;
    using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
    using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
    // Output is produced in the token dtype (ExpandXType from the MC2 template set).
    using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;

    // Broadcast helpers used by the dequant epilogue.
    using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
    using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
    using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;

    using EpilogueTileShape = EpilogueTileShape_;
    using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
    using TileBroadcastOneBlk =
        Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
    using TileOneBlkColumnBroadcastMul =
        Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
    using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
    using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;

    using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
                                                         DType, TileRowBroadcastMul, TileBroadcastOneBlk,
                                                         TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;

    using BlockScheduler = BlockScheduler_;

    // kernel level
    using ElementGroupList = int64_t;
    using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<
        TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;

    // The combiner pointer is forwarded opaquely; the kernel invokes it after dequant.
    typename GemmKernel::Params params{
        problemShape, groupCount,      gmGroupList,         gmA, layoutA, gmB,         layoutB, gmScale,
        layoutScale,  gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};

    // call a kernel
    GemmKernel gemm;
    gemm(params);
}
|
||||
|
||||
// Orchestrates the fused MoE decode pipeline on device:
// token dispatch -> GMM1 (+dequant/swiglu/quant) -> GMM2 (+dequant) -> combine.
// Init() only records addresses and tiling values; all work happens in Process().
template <TemplateMC2TypeClass>
class DispatchGmmCombineDecode
{
public:
    __aicore__ inline DispatchGmmCombineDecode(){};
    __aicore__ inline void Init(
        // input
        GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
        GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
        // output
        GM_ADDR output, GM_ADDR outputRecvCount,
        // system
        GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
    __aicore__ inline void Process();

private:
    // Raw GM addresses captured in Init().
    GM_ADDR gmX_;                 // input tokens
    GM_ADDR gmexpertIds_;         // per-token top-k expert ids
    GM_ADDR gmPermuteWeight1_;    // gmm1 permuted weight
    GM_ADDR gmPermuteScale1_;     // gmm1 permuted weight scale
    GM_ADDR gmWeight2_;           // gmm2 weight
    GM_ADDR gmScale2_;            // gmm2 weight scale
    GM_ADDR gmOutput_;            // combined output tokens
    GM_ADDR gmOutputRecvCount_;   // per-expert received token counts
    GM_ADDR workspaceGM_;         // device workspace base (carved up in Process())
    GM_ADDR gmSmoothScales_;      // expert smooth scales (currently unused downstream)
    GM_ADDR gmexpertScales_;      // per-token expert weighting for the combine

    // GMM problem geometry (see Init() for how these are derived from the tiling data).
    uint32_t m_{0};               // worst-case rows for gmm1/gmm2
    uint32_t n_{0};               // gmm1 output width (gmm1HLen)
    uint32_t k_{0};               // hidden size h
    uint32_t groupCount_{0};      // expert groups handled by this rank
    uint32_t n2_{0};              // gmm2 output width (== h)
    uint32_t k2_{0};              // gmm2 reduce dim (== n_/2 after swiglu)
    uint32_t globalRankId_{0};
    uint32_t winSizePerRank_{0};
    uint32_t blockDim_{0};        // launched block count (aic cores)
    // EP topology / expert layout mirrored from the tiling data.
    uint32_t epRankSize_{0};
    uint32_t epRankId_{0};
    uint32_t moeExpertNum_{0};
    uint32_t moeExpertNumPerRank_{0};
    uint32_t sharedExpertNum_{0};
    uint32_t sharedExpertRankNum_{0};
    uint32_t quantMode_{0};
    uint32_t globalBs_{0};
    uint32_t bs_{0};              // local batch size
    uint32_t maxBs_{0};           // globalBs_ / epRankSize_
    uint32_t topK_{0};

    AscendC::TPipe *tpipe_{nullptr};
    __gm__ HcclOpResParam *winContext_{nullptr};  // HCCL window context for group 0
    const DispatchGmmCombineDecodeTilingData *tilingData_;
};
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
||||
// input
|
||||
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
|
||||
// output
|
||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||
// system
|
||||
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
||||
{
|
||||
tpipe_ = pipe;
|
||||
blockDim_ = AscendC::GetBlockNum();
|
||||
winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
|
||||
|
||||
gmSmoothScales_ = expert_smooth_scales; // not used now
|
||||
gmX_ = x; // input token
|
||||
gmexpertIds_ = expert_ids;
|
||||
gmPermuteWeight1_ = gmm1_permuted_weight;
|
||||
gmPermuteScale1_ = gmm1_permuted_weight_scale;
|
||||
gmWeight2_ = gmm2_weight;
|
||||
gmScale2_ = gmm2_weight_scale;
|
||||
gmOutput_ = output;
|
||||
gmOutputRecvCount_ = outputRecvCount;
|
||||
workspaceGM_ = workspaceGM;
|
||||
gmexpertScales_ = expert_scales;
|
||||
tilingData_ = tilingData;
|
||||
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
|
||||
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
|
||||
moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
|
||||
moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum;
|
||||
sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
|
||||
quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode;
|
||||
globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
|
||||
bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
||||
topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
|
||||
maxBs_ = globalBs_ / epRankSize_;
|
||||
|
||||
bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
|
||||
if (isShareExpert) {
|
||||
m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
|
||||
} else {
|
||||
m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
|
||||
}
|
||||
|
||||
n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
||||
k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||
groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
n2_ = k_;
|
||||
k2_ = n_ / 2;
|
||||
}
|
||||
|
||||
// Runs the full pipeline: (shallow path only) token dispatch on the AIV cores, then
// GMM1 with swiglu+quant epilogue, a cross-core sync, then GMM2 whose dequant epilogue
// drives the combine. Workspace offsets here must mirror SetWorkSpace() on the host.
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
{
    GemmCoord gmm1ProblemShape{m_, n_, k_};
    GemmCoord gmm2ProblemShape{m_, n2_, k2_};

    layout::RowMajor layoutX1{m_, k_};
    layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(k_, n_);
    layout::VectorLayout layoutScale1{n_};
    layout::VectorLayout layoutPerTokenScale1{m_};
    layout::RowMajor layoutX2{m_, k2_};
    layout::nZ layoutWeight2 = layout::nZ::template MakeLayout<int8_t>(k2_, n2_);
    layout::VectorLayout layoutScale2{n2_};
    layout::VectorLayout layoutPerTokenScale2{m_};
    layout::RowMajor layoutOutput{m_, n2_};

    // Carve the flat workspace into the buffers sized by SetWorkSpace() on the host.
    size_t workspaceOffset = 0;
    constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
    GM_ADDR gmX2 = workspaceGM_;  // gmm1 output tokens (int8, swiglu-halved width)
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(int8_t));
    GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
    GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;

    // NOTE(review): gmCVSwap deliberately(?) shares its base address with gmWorkspace —
    // the offset is only advanced afterwards, so the two regions alias. Confirm intended.
    GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (GMM1_L1M * GMM1_L1N) *
                                              WORKSPACE_STAGES * sizeof(int32_t));
    GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(float));
    GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;  // per-expert token counts
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t));
    GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;  // token -> expanded-row mapping
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t));
    GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
    GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset;  // dispatched tokens (gmm1 input)
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(int8_t));
    GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
    GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;  // gmm2 dequant output
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(ExpandXType));
    GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
    workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);

    // Shallow path (EXEC_FLAG == 0): run the token dispatch on the vector cores first;
    // the deep-fused path performs the dispatch inside GmmDeqSwigluQuant instead.
    if constexpr (EXEC_FLAG == 0) {
        if constexpr (g_coreType == AscendC::AIV) {
            AscendC::TPipe tpipe;
            MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
                dispatcher;
            dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList,
                            gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
            dispatcher.Process();
            // Release the pipe before handing UB over to the matmul stages.
            tpipe.Destroy();
            icache_preload(8);
        }

        // Make the cube cores wait until every vector core finished dispatching.
        AscendC::PipeBarrier<PIPE_ALL>();
        Arch::CrossCoreFlag gmm1AivFinished{0};
        if constexpr (g_coreType == AscendC::AIV) {
            Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
            Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
        } else {
            Arch::CrossCoreWaitFlag(gmm1AivFinished);
        }
    }
    // GMM1: grouped matmul + per-token dequant + swiglu + quant; writes gmX2/gmPerTokenScale2.
    GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
                      Gmm1BlockScheduler>(
        gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
        gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
        layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
        gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
        sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_);
    // Second AIV -> AIC handshake before GMM2 consumes the GMM1 results.
    // NOTE(review): this flag reuses id 0, same as the handshake above — confirm the
    // ids are meant to be identical across the two sync points.
    AscendC::PipeBarrier<PIPE_ALL>();
    Arch::CrossCoreFlag gmm1AivFinished{0};
    if constexpr (g_coreType == AscendC::AIV) {
        Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
        Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
    } else {
        Arch::CrossCoreWaitFlag(gmm1AivFinished);
    }

    // The combiner is initialized on the vector cores and invoked from inside GmmDeq's
    // epilogue to scatter/accumulate results back to gmOutput_.
    MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
    if (g_coreType == AscendC::AIV) {
        combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
                      workspaceGM_, nullptr, tilingData_);
    }
    GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
           Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
                               gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut,
                               layoutOutput, gmWorkspace, &combiner);
}
|
||||
#endif // DISPATCH_GMM_COMBINE_DECODE_H
|
||||
@@ -0,0 +1,14 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/epilogue/block/block_epilogue.hpp"
|
||||
|
||||
#include "block_epilogue_per_token_dequant_swiglu.h"
|
||||
#include "block_epilogue_per_token_dequant.hpp"
|
||||
@@ -0,0 +1,760 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
|
||||
#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
|
||||
|
||||
#include "../../raw_distributed/cam_moe_distribute_combine.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
|
||||
#define ENABLE_EP_SEND_COUNT_HASH 0
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
/*
 * Per-token dequantization epilogue block (combine variant) for int32 GEMM accumulators.
 *
 * For each output tile it computes, in fp32:
 *     D[i][j] = Cast( Cast(C[i][j]) * scale[j] * perTokenScale[i] )
 * i.e. a per-output-channel scale broadcast down each column and a per-token scale
 * broadcast across each row, then casts back to half/bfloat16 and writes the tile to GM.
 * Work is streamed tile-by-tile with UB_STAGES-deep buffering; the two AIV sub-blocks
 * of a core process interleaved tiles (loopIdx strides by GetSubBlockNum()).
 *
 * This specialization requires the scale and per-token-scale element types to equal the
 * output element type (half or bfloat16); see the static_asserts below.
 */
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class ScaleType_, class PerTokenScaleType_,
          class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
          class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, ScaleType_, PerTokenScaleType_,
                    DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
                    EpilogueTileSwizzle_>
{
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    // Depth of the UB staging pipeline (number of in-flight tiles).
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;                       // int32 GEMM accumulator
    using LayoutC = typename CType_::Layout;
    using ElementScale = typename ScaleType_::Element;               // per-channel scale
    using LayoutScale = typename ScaleType_::Layout;
    using ElementPerTokenScale = typename PerTokenScaleType_::Element;  // per-token (row) scale
    using LayoutPerTokenScale = typename PerTokenScaleType_::Layout;
    using ElementD = typename DType_::Element;                       // half or bfloat16 output
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(std::is_same_v<ElementC, int32_t> &&
                      (std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>) &&
                      std::is_same_v<ElementScale, ElementD> && std::is_same_v<ElementPerTokenScale, ElementD>,
                  "The element type template parameters of BlockEpilogue are wrong");
    static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
                      std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
                      std::is_same_v<LayoutD, layout::RowMajor>,
                  "The layout template parameters of BlockEpilogue are wrong");

    // Tile compute ops
    using TileRowBroadcastMul = TileRowBroadcastMul_;                 // C * scale (scale broadcast per row)
    using TileBroadcastOneBlk = TileBroadcastOneBlk_;                 // expand per-token scale to one block per row
    using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;  // multiply by per-token scale per column

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
    using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;

    using EpilogueTileSwizzle = EpilogueTileSwizzle_;

    using TileShape = typename TileRowBroadcastMul::TileShape;

    static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
                      std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
                  "TileShape must be consistent for all tile compute ops");

    // Total UB budget: staged C/scale/perTokenScale/D buffers plus the fp32 scratch
    // buffers and the one-block-per-row broadcast buffer must fit in UB.
    static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
                                TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
                   (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) +
                   TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
                  "TileShape is too large to fit in UB");

    // GM pointers and layouts for the scale vectors and the output matrix.
    struct Params {
        __gm__ ElementScale *ptrScale{nullptr};
        LayoutScale layoutScale{};
        __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
        LayoutPerTokenScale layoutPerTokenScale{};
        __gm__ ElementD *ptrD{nullptr};
        LayoutD layoutD{};

        CATLASS_DEVICE
        Params() {};

        CATLASS_DEVICE
        Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
               __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
               __gm__ ElementD *ptrD_, LayoutD const &layoutD_)
            : ptrScale(ptrScale_),
              layoutScale(layoutScale_),
              ptrPerTokenScale(ptrPerTokenScale_),
              layoutPerTokenScale(layoutPerTokenScale_),
              ptrD(ptrD_),
              layoutD(layoutD_)
        {}
    };

    // Carves the staged UB buffers and fp32 scratch out of the shared UB arena,
    // assigns hardware event IDs per stage, and primes the V_MTE2 / MTE3_V flags
    // so the first WaitFlag of each stage in operator() does not block.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        size_t ubOffset = 0;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += TileShape::COUNT * sizeof(ElementC);
            ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
            ubOffset += TileShape::COLUMN * sizeof(ElementScale);
            ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
            ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += TileShape::COUNT * sizeof(ElementD);

            // Distinct event IDs per stage and per buffer, drawn from independent
            // counters for each hardware-event direction.
            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbScaleVMTE2List[i] = eventVMTE2++;
            eventUbScaleMTE2VList[i] = eventMTE2V++;
            eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
            eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;

            // Pre-set the "buffer free" flags so the first iteration's waits pass.
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
        // Unstaged fp32 scratch buffers, shared by all stages.
        ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::COUNT * sizeof(float);
        ubScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::COLUMN * sizeof(float);
        ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::COUNT * sizeof(float);
        ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::ROW * sizeof(float);
        ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::ROW * BYTE_PER_BLK;
        // The per-token multiply is done in place on ubMul (aliased on purpose).
        ubPerTokenMul = ubMul;
    }

    // Drains the flags primed in the constructor / re-set at the end of each stage,
    // leaving the event IDs balanced for the next user of the UB arena.
    CATLASS_DEVICE
    ~BlockEpilogue()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }

    // Replace the GM pointers/layouts for the next block (e.g. next expert group).
    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }

    // Dequantize one GEMM output block: for every tile assigned to this AIV sub-block,
    // load C / scale / per-token scale, compute the fp32 product, cast to ElementD and
    // store to gmD. `callback` is invoked once up front (caller-provided hook).
    CATLASS_DEVICE
    void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
                    GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
                    LayoutC const &layoutBlockC, Callback &&callback = Callback{})
    {
        // Nothing accumulated for this block; skip (and do not run the callback).
        if (actualBlockShapeMNK.k() == 0) {
            return;
        }
        callback();

        // Calculate the offset of the current block
        MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
        MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
        MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
        MatrixCoord blockOffset = blockCoord * blockShape;

        AscendC::GlobalTensor<ElementScale> gmScale;
        gmScale.SetGlobalBuffer(params.ptrScale);
        AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
        gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
        AscendC::GlobalTensor<ElementD> gmD;
        gmD.SetGlobalBuffer(params.ptrD);

        auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
        auto tileShape = TileShape::ToCoord();
        EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
        uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
        // The two vector sub-blocks of an AI core split the tiles round-robin.
        uint32_t subblockIdx = AscendC::GetSubBlockIdx();
        uint32_t subblockNum = AscendC::GetSubBlockNum();
        for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
            auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
            auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
            auto tileOffsetInBlock = tileCoord * tileShape;
            auto tileOffset = blockOffset + tileOffsetInBlock;

            auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
            auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);

            auto &ubC = ubCList[ubListId];
            LayoutC layoutUbC{actualTileShape, ubTileStride};

            // Copy in the int32 accumulator tile once its stage buffer is free.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);

            // Per-channel scale slice: column (axis 1) range of this tile.
            auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
            auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();

            auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
            auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);

            auto &ubScale = ubScaleList[ubListId];
            auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);

            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
            copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);

            // Per-token scale slice: row (axis 0) range of this tile.
            auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
            auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();

            auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
            auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);

            auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
            auto layoutUbPerTokenScale =
                LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);

            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
            copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
                                    layoutGmTilePerTokenScale);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);

            // Widen all three operands to fp32; releasing each input buffer (V_MTE2)
            // immediately after its cast lets the next tile's copy-in overlap.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);

            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
            AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);

            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
            AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);

            // ubMul = C * scale; broadcast per-token scale to one block per row, then
            // multiply column-wise in place (ubPerTokenMul aliases ubMul).
            tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32);
            tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32);
            AscendC::PipeBarrier<PIPE_V>();
            tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb);
            AscendC::PipeBarrier<PIPE_V>();

            auto &ubD = ubDList[ubListId];
            LayoutD layoutUbD{actualTileShape, ubTileStride};

            // Cast to the output type and copy the tile back to GM.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);

            auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
            auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);

            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);

            // Advance to the next staging slot (wraps to 0 after UB_STAGES - 1).
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }

private:
    Params params;

    // Staged UB buffers, one set per pipeline slot.
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
    AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];

    // Hardware-event IDs per stage (V<->MTE2 for inputs, V<->MTE3 for the output).
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbScaleVMTE2List[UB_STAGES];
    int32_t eventUbScaleMTE2VList[UB_STAGES];
    int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
    int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];

    // Current staging slot index.
    uint32_t ubListId{0};

    // Shared fp32 scratch (not staged).
    AscendC::LocalTensor<float> ubCFp32;
    AscendC::LocalTensor<float> ubScaleFp32;
    AscendC::LocalTensor<float> ubMul;
    AscendC::LocalTensor<float> ubPerTokenScaleFp32;
    AscendC::LocalTensor<float> ubPerTokenScaleFp32Brcb;
    AscendC::LocalTensor<float> ubPerTokenMul;  // aliases ubMul (set in the constructor)

    TileRowBroadcastMul tileRowBroadcastMul;
    TileBroadcastOneBlk tileBroadcastOneBlk;
    TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;

    CopyGmToUbC copyGmToUbC;
    CopyGmToUbScale copyGmToUbScale;
    CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
    CopyUbToGmD copyUbToGmD;
};
|
||||
|
||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
|
||||
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
||||
class TileCopy_, class EpilogueTileSwizzle_>
|
||||
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, Gemm::GemmType<float, LayoutScale_>,
|
||||
Gemm::GemmType<float, LayoutPerTokenScale_>, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_,
|
||||
TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_>
|
||||
{
|
||||
public:
|
||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
|
||||
|
||||
// Data infos
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using ElementScale = float;
|
||||
using LayoutScale = LayoutScale_;
|
||||
using ElementPerTokenScale = float;
|
||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||
using ElementD = typename DType_::Element;
|
||||
using LayoutD = typename DType_::Layout;
|
||||
|
||||
// Check data infos
|
||||
static_assert(std::is_same_v<ElementC, int32_t> &&
|
||||
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
|
||||
"The element type template parameters of BlockEpilogue are wrong");
|
||||
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
|
||||
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
|
||||
std::is_same_v<LayoutD, layout::RowMajor>,
|
||||
"The layout template parameters of BlockEpilogue are wrong");
|
||||
|
||||
// Tile compute ops
|
||||
using TileRowBroadcastMul = TileRowBroadcastMul_;
|
||||
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
|
||||
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
|
||||
|
||||
// Tile copy
|
||||
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
|
||||
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
|
||||
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
|
||||
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
|
||||
|
||||
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
|
||||
|
||||
using TileShape = typename TileRowBroadcastMul::TileShape;
|
||||
|
||||
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
|
||||
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
|
||||
"TileShape must be consistent for all tile compute ops");
|
||||
|
||||
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
|
||||
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
||||
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
|
||||
ArchTag::UB_SIZE,
|
||||
"TileShape is too large to fit in UB");
|
||||
|
||||
// GM pointers and layouts for the float scale vectors and the output matrix
// used by this (float-scale) BlockEpilogue specialization.
struct Params {
    __gm__ ElementScale *ptrScale{nullptr};              // per-channel scale vector
    LayoutScale layoutScale{};
    __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};  // per-token scale vector
    LayoutPerTokenScale layoutPerTokenScale{};
    __gm__ ElementD *ptrD{nullptr};                      // dequantized output
    LayoutD layoutD{};

    CATLASS_DEVICE
    Params() {};

    CATLASS_DEVICE
    Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
           __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
           __gm__ ElementD *ptrD_, LayoutD const &layoutD_)
        : ptrScale(ptrScale_),
          layoutScale(layoutScale_),
          ptrPerTokenScale(ptrPerTokenScale_),
          layoutPerTokenScale(layoutPerTokenScale_),
          ptrD(ptrD_),
          layoutD(layoutD_)
    {}
};
|
||||
|
||||
CATLASS_DEVICE void AlignUbOffset()
{
    // Round ubOffset up to the next MoeDistributeCombineImpl::UB_ALIGN boundary.
    // UB_ALIGN is assumed to be a power of two, so the mask isolates the misaligned tail.
    const size_t misaligned = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1);
    ubOffset += (misaligned != 0) ? (MoeDistributeCombineImpl::UB_ALIGN - misaligned) : 0;
}
|
||||
|
||||
// Carves the staged UB buffers and fp32 scratch out of the shared UB arena, assigns
// event IDs, and primes the V_MTE2 / MTE3_V flags so the first waits do not block.
// When the deep-fuse flag is set, additionally loads the per-expert send-count table
// from GM into UB (epSendCountLocal_), which DoCombineSend later reads scalar-wise.
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo,
              Params const &params = Params{})
    : resource(resource), calcInfo(calcInfo), params(params)
{
    for (uint32_t i = 0; i < UB_STAGES; ++i) {
        ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
        ubOffset += TileShape::COUNT * sizeof(ElementC);
        ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
        ubOffset += TileShape::COLUMN * sizeof(ElementScale);
        ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
        ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
        ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
        ubOffset += TileShape::COUNT * sizeof(ElementD);

        // Distinct event IDs per stage and per buffer.
        eventUbCVMTE2List[i] = eventVMTE2++;
        eventUbCMTE2VList[i] = eventMTE2V++;
        eventUbScaleVMTE2List[i] = eventVMTE2++;
        eventUbScaleMTE2VList[i] = eventMTE2V++;
        eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
        eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
        eventUbDMTE3VList[i] = eventMTE3V++;
        eventUbDVMTE3List[i] = eventVMTE3++;

        // Pre-set "buffer free" flags so the first iteration's waits pass.
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
    }
    // Shared fp32 scratch; note there is no ubScaleFp32 here because this
    // specialization's scales are already float.
    ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    ubOffset += TileShape::COUNT * sizeof(float);
    ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    ubOffset += TileShape::COUNT * sizeof(float);
    ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    ubOffset += TileShape::ROW * BYTE_PER_BLK;
    // The per-token multiply result reuses ubCFp32 (aliased on purpose).
    ubPerTokenMul = ubCFp32;

    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        AlignUbOffset();
        epSendCountLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
        ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t);
        AlignUbOffset();
        // Load the EP send-count prefix table from GM. Shared-expert ranks only
        // need the first epWorldSize_ entries.
        AscendC::GlobalTensor<int32_t> epSendCountGM;
        epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_);
        uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_;
        AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast<uint32_t>(epSendCountSize * sizeof(uint32_t)),
                                                      0U, 0U, 0U};
        AscendC::DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
        AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
        // Scalar reads of epSendCountLocal_ follow; make the copy visible first.
        AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
#if ENABLE_EP_SEND_COUNT_HASH
        tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
        // Size the token->rank hash buffer by the largest per-expert send count.
        uint32_t maxGroupSendCount = 0;
        uint32_t groupSendCount = 0;
        for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) {
            uint32_t prevGroupSendCount = groupSendCount;
            groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1);
            if (maxGroupSendCount < groupSendCount - prevGroupSendCount) {
                maxGroupSendCount = groupSendCount - prevGroupSendCount;
            }
        }
        ubOffset += maxGroupSendCount * sizeof(int32_t);
        AlignUbOffset();
        // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or
        // AscendC::TOTAL_VEC_LOCAL_SIZE
#endif
    }
}
|
||||
|
||||
// Drains the per-stage flags primed in the constructor so event IDs are left
// balanced for the next user of the UB arena.
CATLASS_DEVICE
~BlockEpilogue()
{
    for (uint32_t i = 0; i < UB_STAGES; ++i) {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
    }
}
|
||||
|
||||
// Replace the GM pointers/layouts for the next block (e.g. next expert group).
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
    params = params_;
}
|
||||
|
||||
// Resolve the HCCL window base address for `rankId`: the local window when the
// target is this rank, otherwise the peer's windowsIn from the remote-resource
// table. The returned address is offset by the data-region offset, the local
// expert slot, and a per-rank stride (OPT_RANK_OFFSET — presumably to spread
// ranks across the window; semantics defined elsewhere).
CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U)
{
    return (GM_ADDR)((calcInfo.epRankId_ == rankId)
                         ? calcInfo.epWinContext_->localWindowsIn
                         : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr))
                               ->windowsIn) +
           calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
}
|
||||
#if ENABLE_EP_SEND_COUNT_HASH
// Fill tokenToEpRankHashLocal_[hashOffset : hashOffset + copyLen) with `epRank`,
// advancing hashOffset. Duplicate with a bitmask handles the leading span up to
// the next 8-element boundary; a plain Duplicate fills the rest.
CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen)
{
    constexpr uint32_t DUPLICATE_MASK_COUNT = 8;
    uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1));
    if (hashOffsetMask != 0) {
        // Partial leading group: build a bitmask selecting only the elements
        // from hashOffsetMask up to min(copyLen, end of the 8-element group).
        uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask;
        if (copyLen < remainMaskCount) {
            remainMaskCount = copyLen;
        }
        uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask;
        AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, &copyMask, 1, 1,
                                    DUPLICATE_MASK_COUNT);
        hashOffset += remainMaskCount;
        copyLen -= remainMaskCount;
    }
    if (copyLen > 0) {
        // Aligned remainder: fill contiguously.
        AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen);
        hashOffset += copyLen;
    }
}
#endif
|
||||
|
||||
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
{
    // When this rank hosts a shared expert and the destination index falls inside the
    // shared-expert rank range, the remote/local roles are swapped (the data stays on
    // this rank's window, indexed by the peer). NOTE(review): exact window semantics
    // are defined by the combine implementation — confirm against CamMoeDistributeCombine.
    const bool swapRoles = calcInfo.isShardExpert_ && (epRank < calcInfo.sharedExpertRankNum_);
    remoteEpRank = swapRoles ? calcInfo.epRankId_ : epRank;
    localEpRank = swapRoles ? epRank : calcInfo.epRankId_;
}
|
||||
|
||||
// Scatter a dequantized UB tile directly into the destination ranks' HCCL windows
// (deep-fuse combine send). Token rows of the tile are partitioned by the prefix-sum
// send-count table epSendCountLocal_: consecutive token ranges belong to consecutive
// EP ranks, and each range is written with one strided DataCopyPad.
// groupOffsetD/tileOffsetD are element offsets into the per-expert output; axisH_ is
// the per-token hidden size, so offsetD / axisH_ yields the first token row.
CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor<ElementD> &ubD, layout::RowMajor &layoutGmTileD,
                                  LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD)
{
    // Per-row copy geometry: contiguous bytes per token, plus src (UB, in C0 blocks)
    // and dst (GM, in bytes) gaps between consecutive rows.
    const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD);
    const uint32_t copyTokenSrcStride =
        (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD));
    const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD);

    // Map the element offset to a (token row, intra-token offset) pair.
    int64_t offsetD = groupOffsetD + tileOffsetD;
    uint32_t startToken = offsetD / calcInfo.axisH_;
    uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
    uint32_t itToken = startToken;
    uint32_t endToken = startToken + layoutGmTileD.shape(0);
#if ENABLE_EP_SEND_COUNT_HASH
    // Hash table gives the first candidate rank directly instead of scanning from 0.
    uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken);
#else
    constexpr uint32_t epRankStart = 0;
#endif
    // Prefix sum of tokens sent to ranks before epRankStart for this expert
    // (expertOffset is set by operator() to expertIdx * epWorldSize_).
    uint32_t sendCount =
        expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
    for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
        uint32_t prevSendCount = sendCount;
        sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
        // Tokens [prevSendCount, sendCount) belong to epRank; send the overlap
        // with this tile's [itToken, endToken) range in one strided copy.
        if (prevSendCount <= itToken && itToken < sendCount) {
            uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken;
            AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride,
                                                      copyTokenDstStride, 0);
            uint32_t remoteEpRank;
            uint32_t localEpRank;
            SetCombineSendEpRank(epRank, remoteEpRank, localEpRank);
            GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) +
                             localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_;
            AscendC::GlobalTensor<ElementD> rankWindow;
            rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM);
            // Destination is indexed by the token's position within this rank's
            // share ((itToken - prevSendCount) rows) plus the intra-token offset.
            AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset],
                                 ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams);
            itToken += copyTokenCount;
        }
    }
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK,
|
||||
GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK,
|
||||
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutBlockC,
|
||||
Callback &&callback = Callback{})
|
||||
{
|
||||
if (actualBlockShapeMNK.k() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
||||
expertOffset = expertIdx * calcInfo.epWorldSize_;
|
||||
#if ENABLE_EP_SEND_COUNT_HASH
|
||||
if (currentExpertIdx_ != expertIdx) {
|
||||
uint32_t hashOffset = 0;
|
||||
uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1);
|
||||
for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) {
|
||||
uint32_t prevSendCount = sendCount;
|
||||
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
|
||||
InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount);
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(eventVS);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(eventVS);
|
||||
currentExpertIdx_ = expertIdx;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
callback();
|
||||
// Calculate the offset of the current block
|
||||
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
|
||||
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
|
||||
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
||||
MatrixCoord blockOffset = blockCoord * blockShape;
|
||||
|
||||
AscendC::GlobalTensor<ElementScale> gmScale;
|
||||
gmScale.SetGlobalBuffer(params.ptrScale);
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
||||
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
||||
AscendC::GlobalTensor<ElementD> gmD;
|
||||
gmD.SetGlobalBuffer(params.ptrD);
|
||||
|
||||
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
|
||||
auto tileShape = TileShape::ToCoord();
|
||||
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
|
||||
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
|
||||
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
|
||||
uint32_t subblockNum = AscendC::GetSubBlockNum();
|
||||
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
|
||||
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
|
||||
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
|
||||
auto tileOffsetInBlock = tileCoord * tileShape;
|
||||
auto tileOffset = blockOffset + tileOffsetInBlock;
|
||||
|
||||
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
|
||||
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
|
||||
|
||||
auto &ubC = ubCList[ubListId];
|
||||
LayoutC layoutUbC{actualTileShape, ubTileStride};
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
|
||||
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
|
||||
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
|
||||
|
||||
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
||||
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
||||
|
||||
auto &ubScale = ubScaleList[ubListId];
|
||||
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||
|
||||
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
||||
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
|
||||
|
||||
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
|
||||
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
|
||||
|
||||
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
|
||||
auto layoutUbPerTokenScale =
|
||||
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
|
||||
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
|
||||
layoutGmTilePerTokenScale);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||
tileRowBroadcastMul(ubMul, ubCFp32, ubScale);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
||||
tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
|
||||
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
auto &ubD = ubDList[ubListId];
|
||||
LayoutD layoutUbD{actualTileShape, ubTileStride};
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
||||
|
||||
auto tileOffsetD = params.layoutD.GetOffset(tileOffset);
|
||||
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
||||
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
||||
DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD);
|
||||
} else {
|
||||
auto gmTileD = gmD[tileOffsetD];
|
||||
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||
|
||||
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
Arch::Resource<ArchTag> &resource;
|
||||
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
|
||||
|
||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||
|
||||
int32_t eventUbCVMTE2List[UB_STAGES];
|
||||
int32_t eventUbCMTE2VList[UB_STAGES];
|
||||
int32_t eventUbScaleVMTE2List[UB_STAGES];
|
||||
int32_t eventUbScaleMTE2VList[UB_STAGES];
|
||||
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
|
||||
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
|
||||
int32_t eventUbDMTE3VList[UB_STAGES];
|
||||
int32_t eventUbDVMTE3List[UB_STAGES];
|
||||
|
||||
AscendC::LocalTensor<int32_t> epSendCountLocal_;
|
||||
#if ENABLE_EP_SEND_COUNT_HASH
|
||||
AscendC::LocalTensor<int32_t> tokenToEpRankHashLocal_;
|
||||
uint32_t currentExpertIdx_{static_cast<uint32_t>(-1)};
|
||||
#endif
|
||||
|
||||
size_t ubOffset{0};
|
||||
int32_t eventVMTE2{0};
|
||||
int32_t eventMTE2V{0};
|
||||
int32_t eventMTE3V{0};
|
||||
int32_t eventVMTE3{0};
|
||||
int32_t eventVS{0};
|
||||
int32_t eventMTE2S{0};
|
||||
|
||||
uint32_t expertOffset;
|
||||
|
||||
uint32_t ubListId{0};
|
||||
|
||||
AscendC::LocalTensor<float> ubCFp32;
|
||||
AscendC::LocalTensor<float> ubMul;
|
||||
AscendC::LocalTensor<float> ubPerTokenScaleBrcb;
|
||||
AscendC::LocalTensor<float> ubPerTokenMul;
|
||||
|
||||
TileRowBroadcastMul tileRowBroadcastMul;
|
||||
TileBroadcastOneBlk tileBroadcastOneBlk;
|
||||
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
|
||||
|
||||
CopyGmToUbC copyGmToUbC;
|
||||
CopyGmToUbScale copyGmToUbScale;
|
||||
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
|
||||
CopyUbToGmD copyUbToGmD;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
|
||||
#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
|
||||
@@ -0,0 +1,326 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
|
||||
#include "../tile/tile_stride_muls.h"
|
||||
#include "../tile/tile_stride_binary.h"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
/// Block epilogue for Atlas A2 fusing per-token dequantization with a SwiGLU activation.
///
/// Per tile it computes, in fp32:
///     x[m, n]  = cast(C[m, n]) * scale[n] * perTokenScale[m]      (dequant)
///     D[m, n'] = silu(x[m, n']) * x[m, n' + N/2],  n' in [0, N/2) (SwiGLU)
/// where silu(v) = v / (1 + exp(-v)). The left half of the dequantized tile is the
/// gate, the right half the linear branch (this split follows directly from the
/// chunked stride ops in operator() below).
///
/// UB buffers for C / scale / perTokenScale / D are double-buffered (UB_STAGES <= 2)
/// and protected by explicit V_MTE2 / MTE2_V / V_MTE3 / MTE3_V event-flag handshakes.
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
          class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
          class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>, CType_,
                    Gemm::GemmType<float, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
                    TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
                    EpilogueTileSwizzle_>
{
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    // Depth of the UB double-buffer pipeline (1 or 2; enforced by static_assert below).
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;             // int32 GEMM accumulator input
    using LayoutC = typename CType_::Layout;
    using ElementScale = float;                            // per-channel (column) dequant scale
    using LayoutScale = LayoutScale_;
    using ElementPerTokenScale = float;                    // per-token (row) dequant scale
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;             // fp32 activation output
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(std::is_same_v<ElementC, int32_t> && std::is_same_v<ElementD, float>,
                  "The element type template parameters of BlockEpilogue are wrong");
    static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
                      std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
                      std::is_same_v<LayoutD, layout::RowMajor>,
                  "The layout template parameters of BlockEpilogue are wrong");

    // Tile compute ops
    using TileRowBroadcastMul = TileRowBroadcastMul_;                  // row-vector (scale) broadcast multiply
    using TileBroadcastOneBlk = TileBroadcastOneBlk_;                  // expand per-token scale to one 32B block per row
    using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; // column-wise broadcast multiply

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
    using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;

    using EpilogueTileSwizzle = EpilogueTileSwizzle_;

    using TileShape = typename TileRowBroadcastMul::TileShape;
    static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0,
                  "The per token scale granularity for word calculation must be 32 bytes aligned.");
    static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts.");

    static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
                      std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
                  "TileShape must be consistent for all tile compute ops");

    // SwiGLU halves the N axis: the output tile is [ROW, COLUMN / 2].
    static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2;
    using ChunkTileShape = MatrixShape<TileShape::ROW, CHUNK_TILE_COLUMN>;

    // Strided ops reading a chunk-wide view out of the full-width fp32 tile.
    using TileStrideMuls = Tile::TileStrideMuls<ArchTag, float, ChunkTileShape, ChunkTileShape, TileShape>;
    using TileStrideDiv = Tile::TileStrideDiv<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
                                              ChunkTileShape::COLUMN>;
    using TileStrideMul = Tile::TileStrideMul<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
                                              ChunkTileShape::COLUMN>;

    static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");

    // All staged buffers plus the fp32 scratch (full tile + chunk tile + one 32B
    // block per row) must fit in unified buffer memory.
    static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
                                TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
                   (TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
                      ArchTag::UB_SIZE,
                  "TileShape is too large to fit in UB");

    /// Host-of-kernel parameters: GM pointers and layouts for the two scale
    /// vectors and the output matrix. The C input arrives via operator().
    struct Params {
        __gm__ ElementScale *ptrScale{nullptr};
        LayoutScale layoutScale{};
        __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
        LayoutPerTokenScale layoutPerTokenScale{};
        __gm__ ElementD *ptrD{nullptr};
        LayoutD layoutD{};

        CATLASS_DEVICE
        Params() {};

        CATLASS_DEVICE
        Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
               __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
               __gm__ ElementD *ptrD_, LayoutD const &layoutD_)
            : ptrScale(ptrScale_),
              layoutScale(layoutScale_),
              ptrPerTokenScale(ptrPerTokenScale_),
              layoutPerTokenScale(layoutPerTokenScale_),
              ptrD(ptrD_),
              layoutD(layoutD_)
        {}
    };

    /// Carves all UB buffers out of the shared resource, assigns a distinct event
    /// id per stage/buffer, and pre-sets the "buffer free" flags so the first
    /// WaitFlag in operator() does not deadlock.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        size_t ubOffset = 0;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += TileShape::COUNT * sizeof(ElementC);
            ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
            ubOffset += TileShape::COLUMN * sizeof(ElementScale);
            ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
            ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += TileShape::COUNT * sizeof(ElementD);

            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbScaleVMTE2List[i] = eventVMTE2++;
            eventUbScaleMTE2VList[i] = eventMTE2V++;
            eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
            eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;

            // Mark every input buffer "free" and every output buffer "drained".
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
        // Scratch (not staged): full fp32 tile, one 32B block per row, one chunk tile.
        ubTmpMxN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::COUNT * sizeof(float);
        ubTmpMx32B = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
        ubOffset += TileShape::ROW * BYTE_PER_BLK;
        ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    }

    /// Consumes the flags pre-set in the constructor (or left set by the last
    /// loop iteration) so no event id leaks to subsequent kernels.
    CATLASS_DEVICE
    ~BlockEpilogue()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }

    /// Re-points the epilogue at a new set of GM buffers/layouts (e.g. the next group).
    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }

    /// Runs dequant + SwiGLU over one GEMM output block.
    ///
    /// @param blockShapeMNK       nominal block shape (MN used, K checked for emptiness)
    /// @param blockCoordMNK       block coordinate within the full problem
    /// @param actualBlockShapeMNK actual (possibly edge-clipped) block shape
    /// @param gmBlockC            GM view of this block's int32 accumulator
    /// @param layoutBlockC        layout of gmBlockC
    /// @param callback            invoked once before any tile work (e.g. to release L0C)
    CATLASS_DEVICE
    void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
                    GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
                    LayoutC const &layoutBlockC, Callback &&callback = Callback{})
    {
        if (0 == actualBlockShapeMNK.k()) {
            return;
        }
        callback();
        // Calculate the offset of the current block
        MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
        MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
        MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
        MatrixCoord blockOffset = blockCoord * blockShape;

        AscendC::GlobalTensor<ElementScale> gmScale;
        gmScale.SetGlobalBuffer(params.ptrScale);
        AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
        gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
        AscendC::GlobalTensor<ElementD> gmD;
        gmD.SetGlobalBuffer(params.ptrD);

        auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
        auto ubChunkTileStride = MakeCoord(static_cast<int64_t>(ChunkTileShape::COLUMN), 1L);
        auto tileShape = TileShape::ToCoord();
        EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
        uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
        // Unlike the combine variant, this epilogue runs 1C1V: a single vector
        // subblock walks every tile instead of splitting by GetSubBlockIdx().
        uint32_t subblockIdx = 0; // for 1C1V
        uint32_t subblockNum = 1; // for 1C1V
        for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
            auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
            auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
            auto tileOffsetInBlock = tileCoord * tileShape;
            auto tileOffset = blockOffset + tileOffsetInBlock;

            // Output tile is half as wide as the input tile (SwiGLU fuses the halves).
            auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
            auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);

            auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
            auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);

            auto &ubC = ubCList[ubListId];
            LayoutC layoutUbC{actualTileShape, ubTileStride};

            // Load C: wait until the stage's C buffer is free, then signal data ready.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);

            auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
            auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();

            auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
            auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);

            auto &ubScale = ubScaleList[ubListId];
            auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);

            // Load per-channel scale for this tile's column range.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
            copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);

            auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
            auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();

            auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
            auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);

            auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
            auto layoutUbPerTokenScale =
                LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);

            // Load per-token scale for this tile's row range.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
            copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
                                    layoutGmTilePerTokenScale);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);

            // Dequant: fp32 = cast(C); fp32 *= scale[n]; build 32B-broadcast of perTokenScale.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
            tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
            tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);

            // fp32 *= perTokenScale[m]; then SwiGLU on the two column halves:
            AscendC::PipeBarrier<PIPE_V>();
            tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
            AscendC::PipeBarrier<PIPE_V>();
            tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f);               // t = -x          (left half)
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT); // t = exp(-x)
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT); // t = 1 + exp(-x)
            AscendC::PipeBarrier<PIPE_V>();
            tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);        // t = x / (1 + exp(-x)) = silu(x)
            AscendC::PipeBarrier<PIPE_V>();
            auto &ubD = ubDList[ubListId];
            LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};

            // Right half of the dequantized tile (offset by ChunkN columns).
            auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN);                 // D = silu(left) * right
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);

            auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)];
            auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape);

            // Store D, then mark the stage's D buffer free for reuse.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }

private:
    Params params;

    // Double-buffered UB tiles, one entry per pipeline stage.
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
    AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];

    // Per-stage event ids guarding each buffer (free / data-ready handshakes).
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbScaleVMTE2List[UB_STAGES];
    int32_t eventUbScaleMTE2VList[UB_STAGES];
    int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
    int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];

    // Index of the stage used by the current loop iteration.
    uint32_t ubListId{0};

    // fp32 scratch shared across stages (guarded by the PIPE_V barriers above).
    AscendC::LocalTensor<float> ubTmpMxN;
    AscendC::LocalTensor<float> ubTmpMx32B;
    AscendC::LocalTensor<float> ubTmpMxChunkN;

    TileRowBroadcastMul tileRowBroadcastMul;
    TileBroadcastOneBlk tileBroadcastOneBlk;
    TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;

    TileStrideMuls tileStrideMuls;
    TileStrideDiv tileStrideDiv;
    TileStrideMul tileStrideMul;

    CopyGmToUbC copyGmToUbC;
    CopyGmToUbScale copyGmToUbScale;
    CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
    CopyUbToGmD copyUbToGmD;
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
|
||||
namespace Catlass::Epilogue {
|
||||
|
||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
|
||||
struct EpilogueAtlasA2PerTokenDequantSwiglu {
|
||||
using ArchTag = Arch::AtlasA2;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
|
||||
};
|
||||
|
||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
|
||||
struct EpilogueAtlasA2PerTokenDequantCombine {
|
||||
using ArchTag = Arch::AtlasA2;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue
|
||||
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/catlass.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Tile {
|
||||
|
||||
/// Common base for element-wise binary tile ops whose destination and two source
/// operands each have an independent row stride (in elements). Precomputes the
/// AscendC repeat parameters; derived structs (TileStrideMul / TileStrideDiv)
/// supply the actual vector instruction.
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
          int64_t SRC1_STRIDE_>
struct TileStrideBinary {
    using ArchTag = ArchTag_;
    using ElementCompute = ElementCompute_;
    using TileShape = TileShape_;
    // Row strides (elements between consecutive rows) for dst / src0 / src1.
    static constexpr int64_t DST_STRIDE = DST_STRIDE_;
    static constexpr int64_t SRC0_STRIDE = SRC0_STRIDE_;
    static constexpr int64_t SRC1_STRIDE = SRC1_STRIDE_;

    // Hardware limit: a vector instruction repeats at most 255 times.
    static constexpr uint32_t MAX_REPEAT_TIMES = 255;
    static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementCompute);

    // Row strides expressed in 32B blocks — assumes the strides are block-aligned.
    static constexpr uint32_t DST_BLK_NUM_PER_COLUMN = DST_STRIDE / ELE_NUM_PER_BLK;
    static constexpr uint32_t SRC0_BLK_NUM_PER_COLUMN = SRC0_STRIDE / ELE_NUM_PER_BLK;
    static constexpr uint32_t SRC1_BLK_NUM_PER_COLUMN = SRC1_STRIDE / ELE_NUM_PER_BLK;

    // One repeat covers one row; one repeat processes one vector fractal of columns.
    static constexpr uint32_t ROW_NUM_PER_COMPUTE = MAX_REPEAT_TIMES;
    static constexpr uint32_t COL_NUM_PER_COMPUTE = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute);

    CATLASS_DEVICE
    TileStrideBinary()
    {
        // Contiguous within a row; row-to-row advance set by the per-operand strides.
        repeatParams.dstBlkStride = 1;
        repeatParams.src0BlkStride = 1;
        repeatParams.src1BlkStride = 1;
        repeatParams.dstRepStride = DST_BLK_NUM_PER_COLUMN;
        repeatParams.src0RepStride = SRC0_BLK_NUM_PER_COLUMN;
        repeatParams.src1RepStride = SRC1_BLK_NUM_PER_COLUMN;
    }

    AscendC::BinaryRepeatParams repeatParams;
};
|
||||
|
||||
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
|
||||
int64_t SRC1_STRIDE_>
|
||||
struct TileStrideMul
|
||||
: TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_> {
|
||||
using Base = TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_>;
|
||||
|
||||
CATLASS_DEVICE
|
||||
TileStrideMul() : Base() {}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(AscendC::LocalTensor<typename Base::ElementCompute> const &ubDst,
|
||||
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc0,
|
||||
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc1)
|
||||
{
|
||||
for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) {
|
||||
uint32_t residueM = Base::TileShape::ROW - rowOffset;
|
||||
uint8_t repeatTimes =
|
||||
static_cast<uint8_t>((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM);
|
||||
for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) {
|
||||
uint32_t residueN = Base::TileShape::COLUMN - colOffset;
|
||||
uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN;
|
||||
AscendC::Mul(ubDst[rowOffset * Base::DST_STRIDE + colOffset],
|
||||
ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset],
|
||||
ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
|
||||
int64_t SRC1_STRIDE_>
|
||||
struct TileStrideDiv
|
||||
: TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_> {
|
||||
using Base = TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_>;
|
||||
|
||||
CATLASS_DEVICE
|
||||
TileStrideDiv() : Base() {}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(AscendC::LocalTensor<typename Base::ElementCompute> const &ubDst,
|
||||
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc0,
|
||||
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc1)
|
||||
{
|
||||
for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) {
|
||||
uint32_t residueM = Base::TileShape::ROW - rowOffset;
|
||||
uint8_t repeatTimes =
|
||||
static_cast<uint8_t>((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM);
|
||||
for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) {
|
||||
uint32_t residueN = Base::TileShape::COLUMN - colOffset;
|
||||
uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN;
|
||||
AscendC::Div(ubDst[rowOffset * Base::DST_STRIDE + colOffset],
|
||||
ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset],
|
||||
ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Tile
|
||||
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/catlass.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Tile {
|
||||
|
||||
template <class ArchTag_, class ElementCompute_, class TileShape_, class DstTileShape_, class SrcTileShape_>
|
||||
struct TileStrideMuls {
|
||||
using ArchTag = ArchTag_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TileShape = TileShape_;
|
||||
using DstTileShape = DstTileShape_;
|
||||
using SrcTileShape = SrcTileShape_;
|
||||
|
||||
static_assert(DstTileShape::ROW == SrcTileShape::ROW && DstTileShape::ROW == TileShape::ROW, "Error");
|
||||
|
||||
CATLASS_DEVICE
|
||||
TileStrideMuls() {}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(AscendC::LocalTensor<ElementCompute> const &ubDst,
|
||||
AscendC::LocalTensor<ElementCompute> const &ubSrc, ElementCompute scalar)
|
||||
{
|
||||
constexpr uint32_t maxRepeatTimes = 255;
|
||||
constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute);
|
||||
|
||||
constexpr uint32_t dstBlkNumPerColumn = DstTileShape::COLUMN / eleNumPerBlk;
|
||||
constexpr uint32_t srcBlkNumPerColumn = SrcTileShape::COLUMN / eleNumPerBlk;
|
||||
AscendC::UnaryRepeatParams repeatParams;
|
||||
repeatParams.dstBlkStride = 1;
|
||||
repeatParams.srcBlkStride = 1;
|
||||
repeatParams.dstRepStride = dstBlkNumPerColumn;
|
||||
repeatParams.srcRepStride = srcBlkNumPerColumn;
|
||||
|
||||
constexpr uint32_t rowNumPerCompute = maxRepeatTimes;
|
||||
constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute);
|
||||
for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) {
|
||||
uint32_t residueM = TileShape::ROW - rowOffset;
|
||||
uint8_t repeatTimes = static_cast<uint8_t>((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM);
|
||||
for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) {
|
||||
uint32_t residueN = TileShape::COLUMN - colOffset;
|
||||
uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN;
|
||||
AscendC::Muls(ubDst[rowOffset * DstTileShape::COLUMN + colOffset],
|
||||
ubSrc[rowOffset * SrcTileShape::COLUMN + colOffset], scalar, mask, repeatTimes,
|
||||
repeatParams);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Tile
|
||||
@@ -0,0 +1,13 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/gemm/block/block_mmad.hpp"
|
||||
|
||||
#include "block_mmad_preload_async_with_callback_resident_a.h"
|
||||
@@ -0,0 +1,420 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/coord.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
#include "catlass/gemm/helper.hpp"
|
||||
|
||||
namespace Catlass::Gemm::Block {
|
||||
|
||||
/// BlockMmad specialization for the MmadAtlasA2PreloadAsyncWithCallbackResidentA
/// dispatch policy.
///
/// Implements an asynchronous, software-pipelined block matmul on Atlas A2:
///   * GM->L1 loads for up to PRELOAD_STAGES k-tiles are emitted ahead of the
///     corresponding mmad work (operator() enqueues, L1TileMmad consumes,
///     SynchronizeBlock drains the queue).
///   * L1A/L1B/L0A/L0B/L0C buffers are multi-buffered via the *_STAGES counts,
///     with AscendC hardware-event flags handshaking between the copy engines
///     (MTE2/MTE1), the cube unit (M) and the fixpipe (FIX).
///   * On the last k-tile, the accumulator is flushed from L0C to GM, with the
///     caller-supplied callbacks invoked before and after the fixpipe store.
///   * "Resident A": if a call sees the same GM block A as the previous call
///     and the whole k extent fits in the L1A stages (kTileCount == L1A_STAGES),
///     the GM->L1 copy of A is skipped and the data already in L1 is reused.
template <uint32_t PRELOAD_STAGES_, uint32_t L1A_STAGES_, uint32_t L1B_STAGES_, uint32_t L0A_STAGES_,
          uint32_t L0B_STAGES_, uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_,
          class L1TileShape_, class L0TileShape_, class AType_, class BType_, class CType_, class BiasType_,
          class TileCopy_, class TileMmad_>
struct BlockMmad<
    MmadAtlasA2PreloadAsyncWithCallbackResidentA<PRELOAD_STAGES_, L1A_STAGES_, L1B_STAGES_, L0A_STAGES_, L0B_STAGES_,
                                                 L0C_STAGES_, ENABLE_UNIT_FLAG_, ENABLE_SHUFFLE_K_>,
    L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> {
public:
    // Type Aliases
    using DispatchPolicy =
        MmadAtlasA2PreloadAsyncWithCallbackResidentA<PRELOAD_STAGES_, L1A_STAGES_, L1B_STAGES_, L0A_STAGES_,
                                                     L0B_STAGES_, L0C_STAGES_, ENABLE_UNIT_FLAG_, ENABLE_SHUFFLE_K_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;
    using ElementA = typename AType_::Element;
    using LayoutA = typename AType_::Layout;
    using ElementB = typename BType_::Element;
    using LayoutB = typename BType_::Layout;
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using TileMmad = TileMmad_;
    using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
    using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
    using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
    using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
    using CopyL0CToGm = typename TileCopy_::CopyL0CToGm;
    // Accumulator element type is derived from the A/B element pair.
    using ElementAccumulator =
        typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
    using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
    using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
    using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
    using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
    using LayoutCInL0 = layout::zN;

    using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
    using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;

    // Pipeline depths taken from the dispatch policy.
    static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
    static constexpr uint32_t L1A_STAGES = DispatchPolicy::L1A_STAGES;
    static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES;
    static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
    static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
    static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;

    static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
    static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;

    // L1 tile size
    static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
    // L0 tile size
    static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
    static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);

    // Check LayoutC
    static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");

    // Check L1TileShape
    static_assert(L1A_TILE_SIZE * L1A_STAGES + L1B_TILE_SIZE * L1B_STAGES <= ArchTag::L1_SIZE,
                  "L1TileShape exceeding the L1 space!");

    // Check L0TileShape
    static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!");
    static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!");
    static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!");

    static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
                  "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet");

    // Fixed L1-resident layouts for a full L1 tile of A and B.
    static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(L1TileShape::M, L1TileShape::K);
    static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(L1TileShape::K, L1TileShape::N);

    /// Construct the block mmad: carve the L1/L0A/L0B/L0C stage buffers out of
    /// the shared hardware resource and prime all producer-side event flags.
    /// \param l1BufAddrStart byte offset inside the L1 buffer where this
    ///        block's staging area begins (allows sharing L1 with other users).
    CATLASS_DEVICE
    BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
    {
        InitL1(resource, l1BufAddrStart);
        InitL0A(resource);
        InitL0B(resource);
        InitL0C(resource);
    }

    /// Destructor: drain any still-queued preloaded tiles, then consume the
    /// event flags that were set in the Init* routines so every Set has a
    /// matching Wait before teardown.
    CATLASS_DEVICE
    ~BlockMmad()
    {
        SynchronizeBlock();
        for (uint32_t i = 0; i < L1A_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
        }
        for (uint32_t i = 0; i < L1B_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
    }

    /// Process one output block: iterate over k-tiles, issue the GM->L1 loads
    /// for each, and enqueue the matching mmad work into the preload pipeline.
    /// The mmad for a tile is executed PRELOAD_STAGES iterations after its load
    /// was issued (or later, via SynchronizeBlock), overlapping copy and compute.
    ///
    /// \param gmBlockA/layoutA  block of A in global memory and its layout
    /// \param gmBlockB/layoutB  block of B in global memory and its layout
    /// \param gmBlockC/layoutC  destination block of C in global memory
    /// \param actualShape       actual (possibly edge-clipped) m/n/k of this block
    /// \param callbackBeforeFixpipe invoked just before the L0C->GM store of
    ///        this block (deferred until the matching L1TileMmad runs)
    /// \param callbackAfterFixpipe  invoked just after the L0C->GM store
    CATLASS_DEVICE
    void operator()(AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
                    AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
                    AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
                    GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe,
                    Callback const &callbackAfterFixpipe)
    {
        uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
        // Resident-A reuse: only valid when every k-tile of A fits in the L1A
        // stage buffers simultaneously and this call targets the same GM block
        // of A as the previous call.
        bool useResidentA =
            (kTileCount == L1A_STAGES) && (!isFirstLoad) && (gmBlockA.GetPhyAddr() == lastGmBlockA.GetPhyAddr());
        isFirstLoad = false;
        lastGmBlockA = gmBlockA;

        uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
        uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());

        // Shuffle-K: stagger the k-tile start index per core so cores don't all
        // read the same k-tile of B at once.
        uint32_t startTileIdx = 0;
        if constexpr (ENABLE_SHUFFLE_K) {
            startTileIdx = AscendC::GetBlockIdx() % kTileCount;
        }

        for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
            // Wrap-around k-tile index (avoids a modulo).
            uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx)
                                                                       : (startTileIdx + kLoopIdx - kTileCount);

            uint32_t kActual =
                (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);

            // Emission load instruction from GM to L1
            MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
            MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
            auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
            auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
            // Load matrix A tile from GM to L1 (skipped when A is resident).
            // NOTE: the MTE1_MTE2 wait and MTE2_MTE1 set are issued even when
            // the copy is skipped, so the consumer-side handshakes in
            // L1TileMmad stay balanced.
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1AListId]);
            if (!useResidentA) {
                auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
                copyGmToL1A(l1ATensorList[l1AListId], gmTileA, L1A_LAYOUT, layoutTileA);
            }
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1AListId]);
            // Load matrix B tile from GM to L1
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1BListId]);
            auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
            copyGmToL1B(l1BTensorList[l1BListId], gmTileB, L1B_LAYOUT, layoutTileB);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1BListId]);

            // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile
            if (preloadCount == PRELOAD_STAGES) {
                L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            }

            // Store the current load status into the slot the just-issued load
            // will be consumed from (ring buffer of PRELOAD_STAGES entries).
            uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES)
                                                    ? (l1TileMmadParamsId + preloadCount)
                                                    : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
            auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
            l1TileMmadParams.l1AListId = l1AListId;
            l1TileMmadParams.l1BListId = l1BListId;
            l1TileMmadParams.mRound = mRound;
            l1TileMmadParams.nRound = nRound;
            l1TileMmadParams.kActual = kActual;
            l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
            l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
            // Only the last k-tile needs the GM destination and the callbacks:
            // that is when the accumulator is flushed.
            if (kLoopIdx == kTileCount - 1) {
                l1TileMmadParams.gmBlockC = gmBlockC;
                l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
                l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe;
                l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe;
            }

            // Fill the pipeline first; once full, advance the consume cursor.
            if (preloadCount < PRELOAD_STAGES) {
                ++preloadCount;
            } else {
                l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            }
            l1AListId = (l1AListId + 1 < L1A_STAGES) ? (l1AListId + 1) : 0;
            l1BListId = (l1BListId + 1 < L1B_STAGES) ? (l1BListId + 1) : 0;
        }
    }

    /// Drain the preload pipeline: execute every queued L1 tile mmad so that
    /// all enqueued work (including the final fixpipe + callbacks) completes.
    CATLASS_DEVICE
    void SynchronizeBlock()
    {
        while (preloadCount > 0) {
            L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            --preloadCount;
        }
    }

private:
    /// Deferred-work descriptor for one k-tile: which L1 stage buffers hold the
    /// data, the rounded/actual extents, position within the k loop, and (for
    /// the last k-tile only) the GM destination and fixpipe callbacks.
    struct L1TileMmadParams {
        uint32_t l1AListId;      // L1A stage buffer holding this tile of A
        uint32_t l1BListId;      // L1B stage buffer holding this tile of B
        uint32_t mRound;         // m rounded up to the L1A alignment
        uint32_t nRound;         // n rounded up to the L1B alignment
        uint32_t kActual;        // actual k extent of this tile (edge tiles may be short)
        bool isKLoopFirst;       // first k-tile: accumulator must be initialized
        bool isKLoopLast;        // last k-tile: flush L0C to GM and fire callbacks
        AscendC::GlobalTensor<ElementC> gmBlockC;
        LayoutC layoutCInGm;
        Callback callbackBeforeFixpipe;
        Callback callbackAfterFixpipe;

        CATLASS_DEVICE
        L1TileMmadParams() = default;
    };

    /// Carve the L1A/L1B stage buffers out of L1 (A stages first, then B) and
    /// mark every stage as free by setting its MTE1_MTE2 flag.
    CATLASS_DEVICE
    void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
    {
        uint32_t l1AOffset = l1BufAddrStart;
        for (uint32_t i = 0; i < L1A_STAGES; ++i) {
            l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
            l1AEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
        }
        uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1A_STAGES;
        for (uint32_t i = 0; i < L1B_STAGES; ++i) {
            l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
            // B event ids are offset past the A ids to keep them distinct.
            l1BEventList[i] = i + L1A_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
    }

    /// Carve the L0A stage buffers and mark each stage free (M_MTE1 set).
    CATLASS_DEVICE
    void InitL0A(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
            l0AEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
    }

    /// Carve the L0B stage buffers and mark each stage free (M_MTE1 set).
    CATLASS_DEVICE
    void InitL0B(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
            // B event ids are offset past the A ids to keep them distinct.
            l0BEventList[i] = i + L0A_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
    }

    /// Carve the L0C (accumulator) stage buffers and mark each free (FIX_M set).
    CATLASS_DEVICE
    void InitL0C(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
            l0CEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
    }

    /// Consume one queued k-tile: split the L1 tile into L0 sub-tiles along
    /// m/k/n, run the mmad for each, and — on the last k-tile — flush the L0C
    /// accumulator to GM via fixpipe bracketed by the user callbacks.
    CATLASS_DEVICE
    void L1TileMmad(L1TileMmadParams const &params)
    {
        uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
        uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
        uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
        auto &l1ATensor = l1ATensorList[params.l1AListId];
        auto &l1BTensor = l1BTensorList[params.l1BListId];

        auto &l0CTensor = l0CTensorList[l0CListId];
        LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));

        // Without unit flags, the L0C buffer must be explicitly reclaimed from
        // the fixpipe before the first accumulation into it.
        if constexpr (!ENABLE_UNIT_FLAG) {
            if (params.isKLoopFirst) {
                AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            }
        }

        for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
            uint32_t mPartActual =
                (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);

            for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
                uint32_t kPartActual =
                    (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);

                auto &l0ATile = l0ATensorList[l0AListId];
                auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
                auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
                auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];

                AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                // First sub-tile of A: wait for the GM->L1 load to land.
                if ((mPartIdx == 0) && (kPartIdx == 0)) {
                    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1AListId]);
                }
                copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
                // Last sub-tile of A: release the L1A stage back to the loader.
                if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1AListId]);
                }

                for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
                    uint32_t nPartActual =
                        (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);

                    auto &l0BTile = l0BTensorList[l0BListId];
                    auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
                    auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
                    auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];

                    AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    // First sub-tile of B: wait for the GM->L1 load to land.
                    if ((kPartIdx == 0) && (nPartIdx == 0)) {
                        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1BListId]);
                    }
                    copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
                    // Last sub-tile of B: release the L1B stage back to the loader.
                    if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
                        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1BListId]);
                    }

                    // MTE1->M handshake: the cube unit must not consume the L0
                    // tiles until the copies above have completed.
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);

                    auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
                    auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];

                    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
                    // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
                    bool initC = (params.isKLoopFirst && (kPartIdx == 0));
                    // If the unit flag is enabled, the unit flag is set according to the calculation progress
                    uint8_t unitFlag = 0b00;
                    if constexpr (ENABLE_UNIT_FLAG) {
                        if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) &&
                            (nPartIdx == nPartLoop - 1)) {
                            unitFlag = 0b11;
                        } else {
                            unitFlag = 0b10;
                        }
                    }
                    tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);

                    AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
                }
                AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
            }
        }

        // Last k-tile: flush the accumulator to GM, bracketed by the callbacks.
        if (params.isKLoopLast) {
            auto layoutCInGm = params.layoutCInGm;

            params.callbackBeforeFixpipe();

            if constexpr (!ENABLE_UNIT_FLAG) {
                // Explicit M->FIX handshake before the store, then hand the
                // L0C stage back to the cube unit (FIX_M) after it.
                AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
                AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            } else {
                // Unit-flag mode: hardware handles the dependency; 0b11 marks
                // the final store of this block.
                copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
            }
            l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;

            params.callbackAfterFixpipe();
        }
    }

    // L1 staging for A and B: tensors, per-stage event ids, ring cursors.
    AscendC::LocalTensor<ElementA> l1ATensorList[L1A_STAGES];
    AscendC::LocalTensor<ElementB> l1BTensorList[L1B_STAGES];
    int32_t l1AEventList[L1A_STAGES];
    int32_t l1BEventList[L1B_STAGES];
    uint32_t l1AListId{0};
    uint32_t l1BListId{0};

    // L0A staging.
    AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
    int32_t l0AEventList[L0A_STAGES];
    uint32_t l0AListId{0};

    // L0B staging.
    AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
    int32_t l0BEventList[L0B_STAGES];
    uint32_t l0BListId{0};

    // L0C (accumulator) staging.
    AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
    int32_t l0CEventList[L0C_STAGES_];
    uint32_t l0CListId{0};

    // Ring buffer of deferred k-tile work; preloadCount entries are pending,
    // l1TileMmadParamsId is the next entry to consume.
    L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
    uint32_t l1TileMmadParamsId{0};
    uint32_t preloadCount{0};

    // Functor instances for mmad and the various data movements.
    TileMmad tileMmad;
    CopyGmToL1A copyGmToL1A;
    CopyGmToL1B copyGmToL1B;
    CopyL1ToL0A copyL1ToL0A;
    CopyL1ToL0B copyL1ToL0B;
    CopyL0CToGm copyL0CToGm;

    // Resident-A tracking: address of the GM block A seen by the previous call.
    bool isFirstLoad{true};
    AscendC::GlobalTensor<ElementA> lastGmBlockA;
};
|
||||
|
||||
} // namespace Catlass::Gemm::Block
|
||||
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace Catlass::Gemm {
|
||||
|
||||
/// Dispatch-policy tag selecting the asynchronous, preloading Atlas A2
/// BlockMmad variant with fixpipe callbacks and resident-A reuse.
///
/// Pure compile-time configuration: each member forwards one template argument
/// so the matching BlockMmad specialization can read its pipeline depths and
/// feature switches from the policy type. Inherits MmadAtlasA2Async (which,
/// among other things, marks the policy as asynchronous for dispatch).
template <uint32_t PRELOAD_STAGES_, uint32_t L1A_STAGES_, uint32_t L1B_STAGES_, uint32_t L0A_STAGES_,
          uint32_t L0B_STAGES_, uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncWithCallbackResidentA : public MmadAtlasA2Async {
    static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_;  // Stages of emitting load instruction in advance
    static constexpr uint32_t L1A_STAGES = L1A_STAGES_;          // number of L1 stage buffers for matrix A
    static constexpr uint32_t L1B_STAGES = L1B_STAGES_;          // number of L1 stage buffers for matrix B
    static constexpr uint32_t L0A_STAGES = L0A_STAGES_;          // number of L0A stage buffers
    static constexpr uint32_t L0B_STAGES = L0B_STAGES_;          // number of L0B stage buffers
    static constexpr uint32_t L0C_STAGES = L0C_STAGES_;          // number of L0C (accumulator) stage buffers
    static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;  // use hardware unit flags instead of explicit M/FIX events
    static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;  // stagger k-tile start index per core
};
|
||||
|
||||
} // namespace Catlass::Gemm
|
||||
@@ -0,0 +1,355 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
|
||||
#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
|
||||
|
||||
#include "../../raw_distributed/cam_moe_distribute_combine.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/cross_core_sync.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/coord.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
|
||||
namespace Catlass::Gemm::Kernel {
|
||||
|
||||
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_,
|
||||
uint32_t WORKSPACE_STAGES_, class ElementGroupList_>
|
||||
class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
|
||||
{
|
||||
public:
|
||||
using BlockMmad = BlockMmad_;
|
||||
using ArchTag = typename BlockMmad::ArchTag;
|
||||
using L1TileShape = typename BlockMmad::L1TileShape;
|
||||
using ElementA = typename BlockMmad::ElementA;
|
||||
using LayoutA = typename BlockMmad::LayoutA;
|
||||
using ElementB = typename BlockMmad::ElementB;
|
||||
using LayoutB = typename BlockMmad::LayoutB;
|
||||
using ElementC = typename BlockMmad::ElementC;
|
||||
using LayoutC = typename BlockMmad::LayoutC;
|
||||
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
||||
|
||||
using BlockEpilogue = BlockEpilogue_;
|
||||
using ElementScale = typename BlockEpilogue::ElementScale;
|
||||
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
||||
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
||||
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
||||
using ElementD = typename BlockEpilogue::ElementD;
|
||||
using LayoutD = typename BlockEpilogue::LayoutD;
|
||||
using EpilogueParams = typename BlockEpilogue::Params;
|
||||
|
||||
using BlockScheduler = BlockScheduler_;
|
||||
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
|
||||
using ElementGroupList = ElementGroupList_;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
// Data members
|
||||
GemmCoord problemShape;
|
||||
uint32_t problemCount;
|
||||
__gm__ ElementGroupList_ *ptrGroupList;
|
||||
__gm__ ElementA *ptrA;
|
||||
LayoutA layoutA;
|
||||
__gm__ ElementB *ptrB;
|
||||
LayoutB layoutB;
|
||||
__gm__ ElementScale *ptrScale;
|
||||
LayoutScale layoutScale;
|
||||
__gm__ ElementPerTokenScale *ptrPerTokenScale;
|
||||
LayoutPerTokenScale layoutPerTokenScale;
|
||||
__gm__ ElementD *ptrD;
|
||||
LayoutD layoutD;
|
||||
GM_ADDR ptrWorkspace;
|
||||
void *combiner;
|
||||
|
||||
// Methods
|
||||
CATLASS_DEVICE
|
||||
Params() {}
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_,
|
||||
GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_,
|
||||
LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_,
|
||||
void *combiner_)
|
||||
: problemShape(problemShape_),
|
||||
problemCount(problemCount_),
|
||||
ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)),
|
||||
ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)),
|
||||
layoutA(layoutA_),
|
||||
ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)),
|
||||
layoutB(layoutB_),
|
||||
ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)),
|
||||
layoutScale(layoutScale_),
|
||||
ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)),
|
||||
layoutPerTokenScale(layoutPerTokenScale_),
|
||||
ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)),
|
||||
layoutD(layoutD_),
|
||||
ptrWorkspace(ptrWorkspace_),
|
||||
combiner(combiner_)
|
||||
{}
|
||||
};
|
||||
|
||||
// Methods
|
||||
CATLASS_DEVICE
|
||||
GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace()
|
||||
{
|
||||
Arch::FlagID flagId = 0;
|
||||
for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) {
|
||||
flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++);
|
||||
flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++);
|
||||
aicWaitFuncList[stageId] = {this, stageId};
|
||||
aicSetFuncList[stageId] = {this, stageId};
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t CORE_TYPE = g_coreType>
|
||||
CATLASS_DEVICE void operator()(Params const ¶ms);
|
||||
|
||||
template <>
|
||||
CATLASS_DEVICE void operator()<AscendC::AIC>(Params const ¶ms)
|
||||
{
|
||||
BlockScheduler blockScheduler;
|
||||
BlockMmad blockMmad(resource);
|
||||
|
||||
// Represent the full gm
|
||||
AscendC::GlobalTensor<ElementA> gmA;
|
||||
gmA.SetGlobalBuffer(params.ptrA);
|
||||
AscendC::GlobalTensor<ElementB> gmB;
|
||||
gmB.SetGlobalBuffer(params.ptrB);
|
||||
AscendC::GlobalTensor<ElementGroupList> groupList;
|
||||
groupList.SetGlobalBuffer(params.ptrGroupList);
|
||||
|
||||
uint32_t coreIdx = AscendC::GetBlockIdx();
|
||||
uint32_t coreNum = AscendC::GetBlockNum();
|
||||
int64_t gmGroupOffsetA = 0;
|
||||
int64_t gmGroupOffsetB = 0;
|
||||
|
||||
AscendC::GlobalTensor<ElementC> gmC;
|
||||
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
|
||||
auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
|
||||
|
||||
uint32_t stageId = 0;
|
||||
uint32_t stageUsed = 0;
|
||||
uint32_t startCoreIdx = 0;
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
|
||||
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
|
||||
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
|
||||
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
|
||||
|
||||
LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
|
||||
LayoutB layoutB = params.layoutB;
|
||||
|
||||
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
|
||||
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
||||
|
||||
// Determine the starting loopIdx of the current core under the current
|
||||
// groupIdx
|
||||
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
|
||||
// Loop through the matmul of each groupIdx
|
||||
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
|
||||
// Compute block location
|
||||
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
|
||||
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
|
||||
|
||||
Callback callbackBeforeFixpipe{};
|
||||
if (stageUsed == WORKSPACE_STAGES) {
|
||||
callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]);
|
||||
} else {
|
||||
++stageUsed;
|
||||
}
|
||||
Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]);
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
|
||||
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
|
||||
MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
|
||||
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
||||
int64_t gmOffsetB = layoutB.GetOffset(offsetB);
|
||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||
|
||||
// Compute block-scoped matrix multiply-add
|
||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
|
||||
gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe);
|
||||
} else {
|
||||
callbackBeforeFixpipe();
|
||||
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
|
||||
gmC[gmOffsetC], layoutC, actualBlockShape);
|
||||
callbackAfterFixpipe();
|
||||
}
|
||||
|
||||
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
|
||||
}
|
||||
|
||||
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
|
||||
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
|
||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||
}
|
||||
|
||||
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
|
||||
blockMmad.SynchronizeBlock();
|
||||
}
|
||||
|
||||
while (stageUsed > 0) {
|
||||
uint32_t aivComputeStageId =
|
||||
(stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed);
|
||||
Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]);
|
||||
--stageUsed;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
// AIV (vector-core) entry point: runs the per-group dequant epilogue over the
// multi-stage workspace written by the AIC matmul cores, then hands control to
// the MoE combiner (all-to-all send / reduce-permute, or the monolithic
// Process() path when deep-fuse is disabled).
CATLASS_DEVICE void operator()<AscendC::AIV>(Params const &params)
{
    // The combiner object was constructed on the host side and smuggled in as
    // an opaque pointer; its concrete type is fixed by the template arguments.
    auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> *)params.combiner;
    {
        if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
            // Only sub-block 0 raises the receive-side sync event once per AIV pair.
            if (get_subblockid() == 0) {
                AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID);
            }
        }
        BlockScheduler blockScheduler;
        BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo());

        // One logical core index per AIC; the two AIV sub-blocks of a pair share it.
        uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
        uint32_t coreNum = AscendC::GetBlockNum();
        // Running offsets into the per-group scale / per-token-scale / output tensors.
        int64_t gmGroupOffsetScale = 0;
        int64_t gmGroupOffsetPerTokenScale = 0;
        int64_t gmGroupOffsetD = 0;
        AscendC::GlobalTensor<ElementGroupList> groupList;
        groupList.SetGlobalBuffer(params.ptrGroupList);

        // Workspace holds WORKSPACE_STAGES ring slots of one L1 tile per core.
        AscendC::GlobalTensor<ElementC> gmC;
        gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
        auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};

        uint32_t stageId = 0;
        uint32_t startCoreIdx = 0;
        for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
            // groupList stores cumulative token counts; differencing yields this group's M.
            uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
                                                : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
            GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};

            LayoutScale layoutScale = params.layoutScale;
            LayoutPerTokenScale layoutPerTokenScale =
                params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
            LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
            EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
                                          layoutScale,
                                          params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
                                          layoutPerTokenScale,
                                          params.ptrD + gmGroupOffsetD,
                                          layoutD};
            blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
            blockEpilogue.UpdateParams(epilogueParams);
            uint32_t coreLoops = blockScheduler.GetCoreLoops();

            GemmCoord blockShapeMNK = L1TileShape::ToCoord();
            // Rotate the starting core each group so tiles stay evenly distributed;
            // the index math must mirror the AIC side exactly to pair the flags.
            uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
            for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
                GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx);
                GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK);

                // This core's slot for the current ring stage.
                MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
                int64_t gmOffsetC = layoutC.GetOffset(offsetC);
                auto gmBlockC = gmC[gmOffsetC];
                auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN());

                // Wait for the AIC to finish storing this stage, consume it, then
                // release the stage back to the AIC. Order is load-bearing.
                Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]);
                blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC,
                              layoutBlockC);
                Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]);

                stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
            }

            gmGroupOffsetScale += inGroupProblemShape.n();
            gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
            gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n();

            startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
        }
    }

    icache_preload(4);
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        // Deep-fuse mode splits the combine across the two AIV sub-blocks:
        // sub-block 0 sends, sub-block 1 receives/reduces. Each owns the TPipe
        // only for the duration of its call (set, run, clear, destroy).
        if (get_subblockid() == 0) {
            resource.pipe.Init();
            combiner->TPipeSet(&resource.pipe);
            combiner->AllToAllSend();
            combiner->TPipeSet(nullptr);
            resource.pipe.Destroy();
        } else {
            resource.pipe.Init();
            combiner->TPipeSet(&resource.pipe);
            combiner->ReducePermute();
            combiner->TPipeSet(nullptr);
            resource.pipe.Destroy();
        }
    } else {
        // Non-fused mode: the combiner drives the whole send+receive pipeline itself.
        resource.pipe.Init();
        combiner->TPipeSet(&resource.pipe);
        combiner->Process();
        combiner->TPipeSet(nullptr);
        resource.pipe.Destroy();
    }
}
|
||||
|
||||
private:
|
||||
friend struct AicWaitFunc;
|
||||
friend struct AicSetFunc;
|
||||
|
||||
// Callback injected into the AIC matmul loop (callbackBeforeFixpipe): blocks
// until the AIV cores have finished consuming workspace stage `stageId`, so
// the AIC may safely overwrite that slot.
struct AicWaitFunc {
    using MatmulKernel =
        GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
                                                              BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;

    CATLASS_DEVICE
    AicWaitFunc() = default;

    CATLASS_DEVICE
    void operator()() const
    {
        Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]);
    }

    MatmulKernel *ptr{nullptr};  // owning kernel; supplies the flag array
    uint32_t stageId;            // workspace ring slot this callback guards
};
|
||||
|
||||
// Callback injected into the AIC matmul loop (callbackAfterFixpipe): signals
// the AIV cores that workspace stage `stageId` now holds a finished tile and
// is ready for the dequant epilogue.
struct AicSetFunc {
    using MatmulKernel =
        GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
                                                              BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;

    CATLASS_DEVICE
    AicSetFunc() = default;

    CATLASS_DEVICE
    void operator()() const
    {
        // Raised on PIPE_FIX so the flag follows the fixpipe store of the tile.
        Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]);
    }

    MatmulKernel *ptr{nullptr};  // owning kernel; supplies the flag array
    uint32_t stageId;            // workspace ring slot this callback signals
};
|
||||
|
||||
// Per-stage AIC->AIV handshake flags: "stage stored" / "stage consumed".
Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES];
Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES];

// Pre-bound wait/set callbacks, one pair per workspace ring stage.
AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES];
AicSetFunc aicSetFuncList[WORKSPACE_STAGES];
Arch::Resource<ArchTag> resource;  // shared on-chip resources (incl. the TPipe)
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Kernel
|
||||
|
||||
#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,814 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H
|
||||
#define CAM_MOE_DISTRIBUTE_COMBINE_H
|
||||
#define OPT_RANK_OFFSET 512
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../../dispatch_gmm_combine_decode_base.h"
|
||||
#include "../../dispatch_gmm_combine_decode_tiling.h"
|
||||
|
||||
namespace MoeDistributeCombineImpl {
|
||||
constexpr uint8_t BUFFER_NUM = 2; // multi-buf (double buffering for the copy queues)
constexpr uint32_t STATE_OFFSET = 512;            // bytes between per-rank status slots
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
constexpr uint32_t RANK_SIZE_ON_WIN_512 = 512 * 1024;
constexpr uint32_t RANK_SIZE_ON_WIN_256 = 256 * 1024;
constexpr uint32_t TP_RANK_SIZE_ON_WIN = 0;
constexpr uint32_t UB_ALIGN = 32;                 // unified-buffer alignment in bytes
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint8_t EP_DOMAIN = 0;                  // expert-parallel communication domain
constexpr uint8_t TP_DOMAIN = 1;                  // tensor-parallel communication domain
constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; // stride between the two dataState_ halves
constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;        // cross-core event: send side ready
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;       // cross-core event: recv side ready
|
||||
|
||||
template <AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFunc()
|
||||
{
|
||||
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
// Snapshot of combiner state handed to the matmul epilogue so it can write
// results straight into the right HCCL window region.
struct CombineCalcInfo {
    uint64_t expertPerSizeOnWin_;        // bytes reserved per expert in each rank's window
    uint32_t epRankId_;                  // this rank's id in the EP domain
    uint32_t epWorldSize_;               // number of EP ranks
    uint32_t moeExpertPerRankNum_;       // MoE experts hosted per rank
    uint32_t sharedExpertRankNum_;       // ranks dedicated to the shared expert
    uint32_t axisH_;                     // hidden size (elements per token)
    uint32_t moeSendNum_;                // epWorldSize_ * moeExpertPerRankNum_
    bool isShardExpert_;                 // true when this rank serves the shared expert
    GM_ADDR epSendCount_;                // cumulative per-(expert, rank) token counts
    __gm__ HcclOpResParam *epWinContext_; // HCCL window context for the EP domain
    uint64_t winDataSizeOffset_;         // ping/pong offset selected by dataState_
};
|
||||
|
||||
// MoE distribute-combine kernel for decode: gathers expert outputs from all EP
// ranks through HCCL shared-memory windows, optionally reduce-scatters across
// the TP pair, and writes the weighted token sums to XOut.
template <TemplateMC2TypeClass>
class CamMoeDistributeCombine
{
public:
    __aicore__ inline CamMoeDistributeCombine(){};
    // Bind all GM tensors, read tiling, derive window offsets and the per-core
    // rank split. Must run on every AIV before any other member is called.
    __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
                                GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
                                const DispatchGmmCombineDecodeTilingData *tilingData);
    // Full combine pipeline (send + wait + reduce) for the non-fused path.
    __aicore__ inline void Process();
    // Send half of the fused path (runs on AIV sub-block 0).
    __aicore__ inline void AllToAllSend();
    // Receive/reduce half of the fused path (runs on AIV sub-block 1).
    __aicore__ inline void ReducePermute();

    __aicore__ inline CombineCalcInfo &GetCalcInfo()
    {
        return calcInfo_;
    }

    // Attach/detach the TPipe owned by the caller for the current phase.
    __aicore__ inline void TPipeSet(AscendC::TPipe *pipe)
    {
        tpipe_ = pipe;
    }

private:
    __aicore__ inline void InitStatusTargetSum();
    __aicore__ inline void AlltoAllBuffInit();
    __aicore__ inline void ReduceScatterTrans();
    __aicore__ inline void SetWaitTpStatusAndDisPatch();
    __aicore__ inline void CustomAdd(LocalTensor<ExpandXType> &dst, LocalTensor<ExpandXType> &src0,
                                     LocalTensor<ExpandXType> &src1, uint32_t dataCnt);
    __aicore__ inline void ExpertAlltoAllDispatchInnerCopyAdd(uint32_t tokenNumLoop, uint32_t srcStartTokenIdx,
                                                              uint32_t ep, uint32_t expertIdx);
    __aicore__ inline void ExpertAlltoAllDispatchCopyAdd();
    __aicore__ inline void LocalWindowCopy();
    __aicore__ inline void BuffInit();
    __aicore__ inline void SplitCoreCal();
    __aicore__ inline void SetStatus();
    __aicore__ inline void WaitDispatch();

    // Resolve the GM address of rank `rankId`'s data window in the given
    // domain (local window for self, peer window otherwise), offset to the
    // requested local expert slot.
    __aicore__ GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t domain, const uint8_t expertLocalId = 0U)
    {
        if (domain == EP_DOMAIN) {
            return (GM_ADDR)((epRankId_ == rankId)
                                 ? epWinContext_->localWindowsIn
                                 : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr))
                                       ->windowsIn) +
                   winDataSizeOffset_ + expertLocalId * expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
        } else {
            return (GM_ADDR)((tpRankId_ == rankId)
                                 ? tpWinContext_->localWindowsIn
                                 : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr))
                                       ->windowsIn) +
                   winDataSizeOffset_ + rankId * OPT_RANK_OFFSET;
        }
    }

    // Resolve the GM address of rank `rankId`'s status window, selecting the
    // ping/pong half indicated by dataState_.
    __aicore__ GM_ADDR GetWinStateAddrByRankId(const int32_t rankId, const uint8_t domain)
    {
        if (domain == EP_DOMAIN) {
            return (GM_ADDR)((epRankId_ == rankId)
                                 ? epWinContext_->localWindowsExp
                                 : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr))
                                       ->windowsExp) +
                   dataState_ * WIN_STATE_OFFSET;
        } else {
            return (GM_ADDR)((tpRankId_ == rankId)
                                 ? tpWinContext_->localWindowsExp
                                 : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr))
                                       ->windowsExp) +
                   dataState_ * WIN_STATE_OFFSET;
        }
    }

    __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y)
    {
        return (x < y) ? x : y;
    }

    // Type-erased trampoline so the receive half can be invoked via a plain
    // function pointer carrying the combiner as `void *`.
    __aicore__ static void DoCombineRecv(void *ptr)
    {
        auto *combiner = (CamMoeDistributeCombine<TemplateMC2TypeFunc> *)ptr;
        combiner->ReducePermute();
    }

    // --- global-memory tensor views (bound in Init) ---
    TPipe *tpipe_{nullptr};
    GlobalTensor<ExpandXType> expandXGM_;
    GlobalTensor<ExpandIdxType> expertIdsGM_;
    GlobalTensor<ExpandIdxType> expandIdxGM_;
    GlobalTensor<ExpandIdxType> epSendCountGM_;
    GlobalTensor<ExpandIdxType> tpSendCountGM_;
    GlobalTensor<float> expandScalesGM_;
    GlobalTensor<ExpandXType> expandOutGlobal_;
    GlobalTensor<ExpandXType> rankWindow_;
    GlobalTensor<int32_t> rankStates_;
    GlobalTensor<float> epStatusSpaceGlobalTensor_;
    GlobalTensor<float> tpStatusSpaceGlobalTensor_;
    GlobalTensor<ExpandXType> tpRankWindow_;
    GlobalTensor<ExpandXType> rowTmpGlobal_;
    GM_ADDR workspaceGM_;
    GM_ADDR epWindowGM_;
    GM_ADDR epStatusSpaceGm_;
    GM_ADDR tpWindowGM_;
    GM_ADDR tpStatusSpaceGm_;
    GM_ADDR stateGM_;

    // --- unified-buffer working tensors ---
    LocalTensor<ExpandXType> winTpSendCountTensor_;
    LocalTensor<ExpandXType> gmTpSendCountTensor_;
    LocalTensor<ExpandXType> outTensor_;
    LocalTensor<float> winTpSendCountFloatTensor_;
    LocalTensor<float> gmTpSendCountFloatTensor_;
    LocalTensor<ExpandIdxType> epSendCountLocal_;

    // --- scalar configuration / derived state ---
    CombineCalcInfo calcInfo_;
    uint32_t axisBS_{0};
    uint32_t axisMaxBs_{0};
    uint32_t axisH_{0};
    uint32_t axisK_{0};
    uint32_t aivNum_{0};
    uint32_t epWorldSize_{0};
    uint32_t tpWorldSize_{0};
    uint32_t epRankId_{0};
    uint32_t tpRankId_{0};
    uint32_t coreIdx_{0}; // aiv id
    uint32_t sharedExpertRankNum_{0};
    uint32_t moeExpertNum_{0};
    uint32_t moeExpertPerRankNum_{0};
    uint32_t moeSendNum_{0}; // moeExpertPerRankNum_ * epWorldSize_
    uint32_t tpScatterNum_{0};
    uint32_t firstTpTokenEndIdx_{0};
    uint32_t firstTpTokenEndOffset_{0};
    uint32_t endTok_{0};
    __gm__ HcclOpResParam *epWinContext_{nullptr};
    __gm__ HcclOpResParam *tpWinContext_{nullptr};
    uint32_t epDataOffsetOnWin_{0};
    uint32_t tpDataOffsetOnWin_{0};
    uint32_t epStateOffsetOnWin_{0};
    uint32_t tpStateOffsetOnWin_{0};
    uint32_t axisHFloatSize_{0};
    uint32_t axisHExpandXTypeSize_{0};
    uint32_t bsKNum_{0};
    uint32_t startRankId_{0};
    uint32_t endRankId_{0};
    uint32_t sendRankNum_{0};
    uint32_t ubSize_{0};
    uint32_t dataState_{0};
    uint32_t stateOffset_{0};
    uint64_t winDataSizeOffset_{0};
    uint64_t expertPerSizeOnWin_{0};
    uint64_t totalWinSize_{0};

    // --- queues / buffers carved out of the TPipe ---
    TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> moeQueue_;
    TQue<QuePosition::VECIN, 1> moeSumQueue_;
    TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> gmTpSendCountQueue_;
    TQue<QuePosition::VECIN, 1> gmTpSendCountInQueue_;
    TQue<QuePosition::VECIN, 1> winTpSendCountInQueue_;
    TQue<QuePosition::VECOUT, 1> xOutQueue_;
    TBuf<> readStateBuf_;
    TBuf<> expertIdsBuf_;
    TBuf<> expandScalesBuf_;
    TBuf<> rowTmpFloatBuf_;
    TBuf<> sumFloatBuf_;
    TBuf<> mulBuf_;
    TBuf<> sendCountBuf_;
    TBuf<> indexCountsBuf_;
    TBuf<> winTpSendCountFloatBuf_;
    TBuf<> gmTpSendCountFloatBuf_;
    TBuf<> tokenBuf_;
    TBuf<> statusBuf_;
    TBuf<> gatherMaskOutBuf_; // gather mask output buf
    TBuf<> gatherTmpBuf_;
    TBuf<> statusSumOutBuf_;
    float sumTarget_{0.0};    // expected status-flag value for this round (toggles 1.0/0.0)
    int32_t epStateValue_;    // raw bit pattern written to peers' status slots
    bool isShardExpert_{false};
};
|
||||
|
||||
template <TemplateMC2TypeClass>
// Bind GM buffers, read tiling, flip the ping/pong data-state flag, and derive
// every window offset used later. The DataCacheCleanAndInvalid calls bracketed
// by empty asm statements force the status read-modify-write to go through GM
// (the asm acts as a compiler reordering fence).
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
    GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales,
    GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
    tpipe_ = pipe;
    coreIdx_ = GetBlockIdx();
    epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
    auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
    epWinContext_ = (__gm__ HcclOpResParam *)contextGM0;
    // Per-core ping/pong state lives in this rank's own status window; toggle
    // it each invocation so consecutive kernel launches use disjoint halves.
    GlobalTensor<int32_t> selfDataStatusTensor;
    GM_ADDR statusDataSpaceGm = (GM_ADDR)epWinContext_->localWindowsExp;
    selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
        selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
    __asm__ __volatile__("");
    dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN);
    if (dataState_ == 0) {
        selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1;
    } else {
        selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 0;
    }
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
        selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
    __asm__ __volatile__("");
    pipe_barrier(PIPE_ALL);

    // Bind all caller-provided GM buffers.
    workspaceGM_ = workspaceGM;
    expandXGM_.SetGlobalBuffer((__gm__ ExpandXType *)expandX);
    expertIdsGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds);
    expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
    epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
    expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
    expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
    axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
    axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        // Fused mode runs with whatever block count the launcher chose.
        aivNum_ = get_block_num();
    } else {
        aivNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum;
    }
    ubSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalUbSize;
    sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
    moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
    moeExpertPerRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
    axisMaxBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_;
    moeSendNum_ = epWorldSize_ * moeExpertPerRankNum_;
    // TP domain is fixed to a single rank in this kernel variant.
    tpWorldSize_ = 1;
    tpRankId_ = 0;
    totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize;
    // Halve the per-slot stride when there are many (expert, rank) slots so
    // the status region still fits in the window.
    stateOffset_ = (moeSendNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET;
    expertPerSizeOnWin_ =
        static_cast<uint64_t>(axisMaxBs_) * static_cast<uint64_t>(axisH_) * static_cast<uint64_t>(sizeof(ExpandXType));
    winDataSizeOffset_ = static_cast<uint64_t>(dataState_) * static_cast<uint64_t>(moeSendNum_) * expertPerSizeOnWin_;
    epWindowGM_ = GetWinAddrByRankId(epRankId_, EP_DOMAIN);
    epStatusSpaceGm_ = GetWinStateAddrByRankId(epRankId_, EP_DOMAIN);
    epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)epStatusSpaceGm_);
    epDataOffsetOnWin_ = epRankId_ * moeExpertPerRankNum_ * static_cast<uint32_t>(expertPerSizeOnWin_);
    epStateOffsetOnWin_ = epRankId_ * stateOffset_;
    isShardExpert_ = (epRankId_ < sharedExpertRankNum_);
    axisHFloatSize_ = axisH_ * sizeof(float);
    axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType);
    bsKNum_ = axisBS_ * axisK_;

    if constexpr (IsNeedReduceScatter) {
        // TP reduce-scatter path: bind the TP-side window and counts.
        tpSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)tpSendCount);
        tpWindowGM_ = GetWinAddrByRankId(tpRankId_, TP_DOMAIN);
        tpStatusSpaceGm_ = GetWinStateAddrByRankId(tpRankId_, TP_DOMAIN);
        tpStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)tpStatusSpaceGm_);
        tpDataOffsetOnWin_ = tpRankId_ * TP_RANK_SIZE_ON_WIN;
        tpStateOffsetOnWin_ = tpRankId_ * stateOffset_;
        // The peer writes into the half of our window not indexed by our own rank.
        uint32_t tpScatterRankWinOffset = (tpRankId_ == 0) ? TP_RANK_SIZE_ON_WIN : 0;
        GM_ADDR rankGM = tpWindowGM_ + tpScatterRankWinOffset;
        tpRankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
    }

    InitStatusTargetSum();
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        // In fused mode the physical block index replaces the logical one used above.
        coreIdx_ = get_block_idx();
    }
    SplitCoreCal();

    // Publish the fields the matmul epilogue needs (see CombineCalcInfo).
    calcInfo_.epRankId_ = epRankId_;
    calcInfo_.epWorldSize_ = epWorldSize_;
    calcInfo_.expertPerSizeOnWin_ = expertPerSizeOnWin_;
    calcInfo_.moeExpertPerRankNum_ = moeExpertPerRankNum_;
    calcInfo_.sharedExpertRankNum_ = sharedExpertRankNum_;
    calcInfo_.axisH_ = axisH_;
    calcInfo_.moeSendNum_ = moeSendNum_;
    calcInfo_.isShardExpert_ = isShardExpert_;
    calcInfo_.epSendCount_ = epSendCount;
    calcInfo_.epWinContext_ = epWinContext_;
    calcInfo_.winDataSizeOffset_ = winDataSizeOffset_;
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Toggle this core's EP status flag between rounds and record the value peers
// are expected to publish this round (sumTarget_) plus the raw bit pattern we
// will write into peers' slots (epStateValue_). Alternating 1.0/0.0 means a
// stale flag from the previous round can never satisfy the current wait.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::InitStatusTargetSum()
{
    // ep state
    GlobalTensor<int32_t> selfStatusTensor;
    selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(epStatusSpaceGm_ + SELF_STATE_OFFSET));
    // Empty asm + cache-line flush: force the flag read/write to reach GM in order.
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
        selfStatusTensor[coreIdx_ * UB_ALIGN]);
    __asm__ __volatile__("");
    int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN);
    if (state == 0) {
        sumTarget_ = static_cast<float>(1.0);
        selfStatusTensor(coreIdx_ * UB_ALIGN) = 0x3F800000; // 1.0f
        epStateValue_ = 0x3F800000; // 1.0f
    } else {
        sumTarget_ = static_cast<float>(0.0);
        selfStatusTensor(coreIdx_ * UB_ALIGN) = 0;
        epStateValue_ = 0;
    }
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
        selfStatusTensor[coreIdx_ * UB_ALIGN]);
    __asm__ __volatile__("");
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Carve the UB buffers used by the send/dispatch phase out of the TPipe.
// Resets the pipe first, so any previously initialized buffers are discarded.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::BuffInit()
{
    tpipe_->Reset();
    tpipe_->InitBuffer(readStateBuf_, UB_ALIGN);
    // Room for one int32 count per (expert, rank) slot, rounded up to UB alignment.
    uint32_t sendNumAlign = Ceil(moeSendNum_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
    tpipe_->InitBuffer(sendCountBuf_, sendNumAlign);
    if constexpr (IsNeedReduceScatter) {
        // Double-buffered token-sized queues: window input, GM input, summed output.
        tpipe_->InitBuffer(winTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
        tpipe_->InitBuffer(gmTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
        tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
        if constexpr (AscendC::IsSameType<ExpandXType, bfloat16_t>::value) {
            // bf16 additions are done in fp32 scratch (see CustomAdd).
            tpipe_->InitBuffer(winTpSendCountFloatBuf_, axisHFloatSize_);
            tpipe_->InitBuffer(gmTpSendCountFloatBuf_, axisHFloatSize_);
            winTpSendCountFloatTensor_ = winTpSendCountFloatBuf_.Get<float>();
            gmTpSendCountFloatTensor_ = gmTpSendCountFloatBuf_.Get<float>();
        }
    } else {
        // Pure pass-through path needs only one staging queue.
        tpipe_->InitBuffer(gmTpSendCountQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
    }
    epSendCountLocal_ = sendCountBuf_.Get<int32_t>();
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Carve the UB buffers used by the receive/reduce phase. Resets the TPipe, so
// this must not be interleaved with BuffInit()'s buffers.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuffInit()
{
    tpipe_->Reset();
    // One int32 per (token, top-k expert) pair, rounded up to UB alignment.
    uint32_t bsMulTopkSizeAligned = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
    tpipe_->InitBuffer(readStateBuf_, UB_ALIGN);
    tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN);
    tpipe_->InitBuffer(expertIdsBuf_, bsMulTopkSizeAligned);
    tpipe_->InitBuffer(expandScalesBuf_, bsMulTopkSizeAligned);
    tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType));
    // fp32 scratch for the weighted accumulation of one token row.
    tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_);
    tpipe_->InitBuffer(mulBuf_, axisHFloatSize_);
    tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_);
    tpipe_->InitBuffer(indexCountsBuf_, bsMulTopkSizeAligned);
    tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
    // Buffers for the status gather-mask + sum used by WaitDispatch.
    tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
    tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
    tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
}
|
||||
|
||||
template <TemplateMC2TypeClass>
|
||||
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SplitCoreCal()
|
||||
{
|
||||
sendRankNum_ = epWorldSize_ / aivNum_;
|
||||
uint32_t remainderRankNum = epWorldSize_ % aivNum_;
|
||||
startRankId_ = sendRankNum_ * coreIdx_;
|
||||
if (coreIdx_ < remainderRankNum) {
|
||||
sendRankNum_++;
|
||||
startRankId_ += coreIdx_;
|
||||
} else {
|
||||
startRankId_ += remainderRankNum;
|
||||
}
|
||||
endRankId_ = startRankId_ + sendRankNum_;
|
||||
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Stream this core's slice of expandX tokens into the TP peer's window
// (reduce-scatter transfer step). Token range is taken from the cumulative
// epSendCount table, so each core copies a disjoint span.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ReduceScatterTrans()
{
    // Invalidate the cached count before reading it — another agent wrote it.
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(tpSendCountGM_[tpRankId_]);
    __asm__ __volatile__("");
    uint32_t offset = tpSendCountGM_.GetValue(tpRankId_) * axisH_;
    GlobalTensor<ExpandXType> dataCopyInGM = expandXGM_[offset];
    // Destination: the other TP rank's window (1 - tpRankId_ flips between the pair).
    GM_ADDR rankGM = GetWinAddrByRankId(1 - tpRankId_, TP_DOMAIN) + tpDataOffsetOnWin_;
    rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
    uint32_t copyStartIdx = 0;
    if (startRankId_ > 0) {
        __asm__ __volatile__("");
        DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
            epSendCountGM_[epWorldSize_ + startRankId_ - 1]);
        __asm__ __volatile__("");
        // Cumulative count just before our first rank = our starting token index.
        copyStartIdx = epSendCountGM_.GetValue(epWorldSize_ + startRankId_ - 1);
    }
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
        epSendCountGM_[epWorldSize_ + endRankId_ - 1]);
    __asm__ __volatile__("");
    uint32_t copyEndIdx = epSendCountGM_.GetValue(epWorldSize_ + endRankId_ - 1);
    // GM -> UB -> peer-window GM, one token row per loop, via the bound queue.
    LocalTensor<ExpandXType> tmpUb;
    for (uint32_t tokenNumIdx = copyStartIdx; tokenNumIdx < copyEndIdx; tokenNumIdx++) {
        tmpUb = moeQueue_.AllocTensor<ExpandXType>();
        DataCopy(tmpUb, dataCopyInGM[tokenNumIdx * axisH_], axisH_);
        moeQueue_.EnQue(tmpUb);
        tmpUb = moeQueue_.DeQue<ExpandXType>();
        DataCopy(rankWindow_[tokenNumIdx * axisH_], tmpUb, axisH_);
        moeQueue_.FreeTensor<ExpandXType>(tmpUb);
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// TP handshake + dispatch: publish our "data written" flag to the TP peer,
// spin until the peer's matching flag arrives, then run the EP all-to-all
// copy/add. Cores with no assigned ranks exit immediately.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SetWaitTpStatusAndDisPatch()
{
    pipe_barrier(PIPE_ALL);
    if (startRankId_ >= epWorldSize_) {
        return;  // this core owns no destination ranks
    }
    if constexpr (IsNeedReduceScatter) {
        uint32_t tpToRankId = 1 - tpRankId_;  // the other rank of the TP pair
        pipe_barrier(PIPE_ALL);
        // Write sumTarget_ into the peer's status slot for this core.
        LocalTensor<float> statusFlagUb = readStateBuf_.Get<float>();
        statusFlagUb(0) = sumTarget_;
        SyncFunc<AscendC::HardEvent::S_MTE3>();  // scalar write visible before MTE3 copy
        GlobalTensor<float> tpWindowInstatusFp32Tensor_;
        stateGM_ = GetWinStateAddrByRankId(tpToRankId, TP_DOMAIN) + coreIdx_ * stateOffset_;
        tpWindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)stateGM_);
        DataCopy<float>(tpWindowInstatusFp32Tensor_, statusFlagUb, 8UL);
        SyncFunc<AscendC::HardEvent::MTE3_S>();
        // Spin-poll our own status slot until the peer writes the expected value.
        LocalTensor<float> statusFp32Tensor_ = readStateBuf_.Get<float>();
        float sumOfFlag = static_cast<float>(-1.0);
        uint32_t statusRankOffset = coreIdx_ * stateOffset_ / sizeof(float);
        while (sumOfFlag != sumTarget_) {
            DataCopy<float>(statusFp32Tensor_, tpStatusSpaceGlobalTensor_[statusRankOffset], 8);
            SyncFunc<AscendC::HardEvent::MTE2_S>();  // copy must land before the scalar read
            sumOfFlag = statusFp32Tensor_.GetValue(0);
            SyncFunc<AscendC::HardEvent::S_MTE2>();  // read done before the next copy reuses UB
        }
    }
    ExpertAlltoAllDispatchCopyAdd();
    SyncFunc<AscendC::HardEvent::MTE3_S>();
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// For every local expert and every destination rank assigned to this core,
// locate the token span from the cumulative send-count table and forward it
// to the destination's window (with a TP add when reduce-scatter is enabled).
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ExpertAlltoAllDispatchCopyAdd()
{
    if (startRankId_ >= epWorldSize_) {
        return;  // no ranks assigned to this core
    }
    // Shared-expert ranks carry exactly one expert; MoE ranks carry
    // moeExpertPerRankNum_, and the count table is sized accordingly.
    uint32_t curRankExpertNum = 0;
    DataCopyExtParams epSendCntParams;
    if (isShardExpert_) {
        curRankExpertNum = 1;
        epSendCntParams = {1U, static_cast<uint32_t>(epWorldSize_ * sizeof(uint32_t)), 0U, 0U, 0U};
    } else {
        curRankExpertNum = moeExpertPerRankNum_;
        epSendCntParams = {1U, static_cast<uint32_t>(moeSendNum_ * sizeof(uint32_t)), 0U, 0U, 0U};
    }
    DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
    DataCopyPad(epSendCountLocal_, epSendCountGM_, epSendCntParams, copyPadParams);
    SyncFunc<AscendC::HardEvent::MTE2_S>();  // table must be in UB before scalar reads
    uint32_t preCount = 0;
    uint32_t startTokenIdx = 0;
    uint32_t curTokenNum = 0;

    for (uint32_t expertIdx = 0U; expertIdx < curRankExpertNum; expertIdx++) {
        uint32_t sendEpCount = endRankId_ - startRankId_;
        for (uint32_t i = 0; i < sendEpCount; ++i) {
            // Rotate the rank order by epRankId_ to stagger traffic across ranks.
            uint32_t ep = startRankId_ + (i + epRankId_) % sendEpCount;
            // The table is cumulative; the previous entry gives this span's start.
            if ((ep > 0) || (expertIdx > 0U)) {
                preCount = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep - 1);
            } else {
                preCount = 0;
            }
            curTokenNum = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep) - preCount;
            if (curTokenNum == 0) {
                continue;  // nothing routed to this (expert, rank) pair
            }
            startTokenIdx = preCount * axisH_;
            ExpertAlltoAllDispatchInnerCopyAdd(curTokenNum, startTokenIdx, ep, expertIdx);
        }
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Copy `tokenNumLoop` token rows starting at element `srcStartTokenIdx` of
// expandX into rank `ep`'s window slot for local expert `expertIdx`. When
// reduce-scatter is enabled, each row is first summed with the matching row
// the TP peer deposited in our window.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ExpertAlltoAllDispatchInnerCopyAdd(
    uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, uint32_t ep, uint32_t expertIdx)
{
    GM_ADDR rankGM = GetWinAddrByRankId(ep, EP_DOMAIN, expertIdx) + epDataOffsetOnWin_;
    if ((isShardExpert_) && (ep < sharedExpertRankNum_)) {
        // Shared-expert traffic between shared ranks stays in the local window,
        // keyed by the destination rank instead.
        rankGM = GetWinAddrByRankId(epRankId_, EP_DOMAIN, expertIdx) + ep * moeExpertPerRankNum_ * expertPerSizeOnWin_;
    }
    rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
    uint32_t dataCnt = axisH_;  // one full token row per iteration
    for (uint32_t loopIdx = 0; loopIdx < tokenNumLoop; loopIdx++) {
        if constexpr (IsNeedReduceScatter) {
            // Load the GM row and the TP-peer row, add them, store to the window.
            gmTpSendCountTensor_ = gmTpSendCountInQueue_.AllocTensor<ExpandXType>();
            DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt);
            gmTpSendCountInQueue_.EnQue(gmTpSendCountTensor_);

            winTpSendCountTensor_ = winTpSendCountInQueue_.AllocTensor<ExpandXType>();
            DataCopy(winTpSendCountTensor_, tpRankWindow_[srcStartTokenIdx], dataCnt);
            winTpSendCountInQueue_.EnQue(winTpSendCountTensor_);

            gmTpSendCountTensor_ = gmTpSendCountInQueue_.DeQue<ExpandXType>();
            winTpSendCountTensor_ = winTpSendCountInQueue_.DeQue<ExpandXType>();
            outTensor_ = xOutQueue_.AllocTensor<ExpandXType>();

            CustomAdd(outTensor_, winTpSendCountTensor_, gmTpSendCountTensor_, dataCnt);
            gmTpSendCountInQueue_.FreeTensor<ExpandXType>(gmTpSendCountTensor_);
            winTpSendCountInQueue_.FreeTensor<ExpandXType>(winTpSendCountTensor_);
            xOutQueue_.EnQue(outTensor_);

            outTensor_ = xOutQueue_.DeQue<ExpandXType>();
            DataCopy(rankWindow_[loopIdx * dataCnt], outTensor_, dataCnt);
            xOutQueue_.FreeTensor<ExpandXType>(outTensor_);
        } else {
            // Straight GM -> UB -> window copy, no TP reduction.
            gmTpSendCountTensor_ = gmTpSendCountQueue_.AllocTensor<ExpandXType>();
            DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt);
            // NOTE(review): `val` is never used afterwards. This scalar GM read may be
            // a deliberate touch (e.g. to force the line out of cache) — confirm with
            // the original authors before removing.
            ExpandXType val = expandXGM_[srcStartTokenIdx].GetValue(0);
            gmTpSendCountQueue_.EnQue(gmTpSendCountTensor_);
            gmTpSendCountTensor_ = gmTpSendCountQueue_.DeQue<ExpandXType>();
            DataCopy(rankWindow_[loopIdx * dataCnt], gmTpSendCountTensor_, dataCnt);
            gmTpSendCountQueue_.FreeTensor<ExpandXType>(gmTpSendCountTensor_);
        }
        srcStartTokenIdx += dataCnt;  // advance to the next token row
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// dst = src0 + src1 over `dataCnt` elements. For bf16 the addition is done in
// fp32 scratch (winTpSendCountFloatTensor_/gmTpSendCountFloatTensor_, sized in
// BuffInit) to avoid bf16 rounding in the accumulate; fp16 adds directly.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::CustomAdd(LocalTensor<ExpandXType> &dst,
                                                                               LocalTensor<ExpandXType> &src0,
                                                                               LocalTensor<ExpandXType> &src1,
                                                                               uint32_t dataCnt)
{
    if constexpr (AscendC::IsSameType<ExpandXType, bfloat16_t>::value) {
        Cast(winTpSendCountFloatTensor_, src0, RoundMode::CAST_NONE, dataCnt);
        Cast(gmTpSendCountFloatTensor_, src1, RoundMode::CAST_NONE, dataCnt);
        // PIPE_V barriers order the casts, the add, and the cast-back on the vector pipe.
        pipe_barrier(PIPE_V);
        Add(winTpSendCountFloatTensor_, winTpSendCountFloatTensor_, gmTpSendCountFloatTensor_, dataCnt);
        pipe_barrier(PIPE_V);
        Cast(dst, winTpSendCountFloatTensor_, RoundMode::CAST_ROUND, dataCnt);
    } else {
        Add(dst, src0, src1, dataCnt);
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Publish this rank's readiness flag (epStateValue_) into the status window
// of every peer rank in [startRankId_, endRankId_). Cores whose startRankId_
// is beyond the EP world size have no peers to signal and return early.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SetStatus()
{
    // Full-pipe barrier: make sure all prior work is visible before flagging.
    pipe_barrier(PIPE_ALL);
    if (startRankId_ >= epWorldSize_) {
        return;
    }

    LocalTensor<int32_t> statusFlagUb = readStateBuf_.Get<int32_t>();
    statusFlagUb.SetValue(0, epStateValue_);
    // Scalar write must complete before the MTE3 copies read the flag.
    SyncFunc<AscendC::HardEvent::S_MTE3>();

    for (uint32_t epIdx = startRankId_; epIdx < endRankId_; epIdx++) {
        stateGM_ = GetWinStateAddrByRankId(epIdx, EP_DOMAIN) + epStateOffsetOnWin_;
        rankStates_.SetGlobalBuffer((__gm__ int32_t *)stateGM_);
        // 8 int32 elements = one 32B block, the minimum DataCopy granule.
        DataCopy(rankStates_, statusFlagUb, 8);
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Spin until every peer this core is responsible for has raised its status
// flag, then synchronize all cores. The flags are read as floats and summed;
// completion is detected when the sum lands within +/-0.5 of
// sumTarget_ * sendRankNum_ (tolerates float accumulation noise).
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatch()
{
    if (startRankId_ < epWorldSize_) {
        LocalTensor<float> statusTensor = statusBuf_.Get<float>();
        LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf_.Get<float>();
        LocalTensor<uint32_t> gatherTmpTensor = gatherTmpBuf_.Get<uint32_t>();
        LocalTensor<float> statusSumOutTensor = statusSumOutBuf_.Get<float>();
        PipeBarrier<PIPE_ALL>();

        // Gather mask: keep only element 0 of each status record.
        gatherTmpTensor.SetValue(0, 1);
        uint32_t mask = 1; // gatherMask + sum
        uint64_t rsvdCnt = 0;
        // One burst per peer; stride between status records depends on the
        // send fan-out (7 vs 15 blocks).
        DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
                                   static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks
        float sumOfFlag = static_cast<float>(-1.0);
        float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
        float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
        SumParams sumParams{1, sendRankNum_, sendRankNum_};
        SyncFunc<AscendC::HardEvent::S_V>();
        while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
            // Re-read the peers' status flags from the shared window.
            DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
                            intriParams);
            SyncFunc<AscendC::HardEvent::MTE2_V>();
            GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
                       {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
            PipeBarrier<PIPE_V>();
            Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
            SyncFunc<AscendC::HardEvent::V_S>();
            sumOfFlag = statusSumOutTensor.GetValue(0);
        }
    }

    // All cores rendezvous: cross-core event in deep-fuse mode, SyncAll
    // otherwise.
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
        AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
    } else {
        SyncAll<true>();
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Combine phase: for each token owned by this core, gather the expert
// outputs that peers deposited into this rank's window, scale each by its
// routing weight, accumulate in float, optionally add the shared-expert
// contribution, and write the bf16/fp16 result to expandOutGlobal_.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
{
    uint32_t beginIndex = 0;
    uint32_t endIndex = 0;
    uint32_t processLen = 0;   // number of hidden-dim elements this core handles
    uint32_t tokenOffset = 0;  // element offset into the hidden dim
    if (axisBS_ < aivNum_) {
        // Fewer tokens than vector cores: several cores split one token's
        // hidden dimension (UB_ALIGN-aligned slices; last core takes the
        // remainder).
        uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_
        if (coreIdx_ >= (axisBS_ * aivNumPerToken)) {
            return;
        }
        uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
        processLen = ((axisH_ / UB_ALIGN) / aivNumPerToken) * UB_ALIGN;
        tokenOffset = processLen * (coreIdx_ % aivNumPerToken);
        if ((coreIdx_ % aivNumPerToken) == (aivNumPerToken - 1)) {
            processLen = axisH_ - ((aivNumPerToken - 1) * processLen);
        }
        beginIndex = tokenIndex;
        endIndex = beginIndex + 1U;
    } else {
        // More tokens than cores: distribute whole tokens, first
        // (axisBS_ % aivNum_) cores take one extra token each.
        uint32_t tokenPerAivNum = axisBS_ / aivNum_;
        uint32_t remainderToken = axisBS_ % aivNum_;
        beginIndex = tokenPerAivNum * coreIdx_;
        if (coreIdx_ < remainderToken) {
            tokenPerAivNum++;
            beginIndex = tokenPerAivNum * coreIdx_;
        } else {
            beginIndex += remainderToken;
        }
        endIndex = beginIndex + tokenPerAivNum;
        processLen = axisH_;
    }
    LocalTensor<ExpandIdxType> expertIdsLocal = expertIdsBuf_.Get<ExpandIdxType>();
    LocalTensor<float> expandScalesLocal = expandScalesBuf_.Get<float>();

    LocalTensor<float> rowTmpFloatLocal = rowTmpFloatBuf_.Get<float>();
    LocalTensor<float> mulBufLocal = mulBuf_.Get<float>();
    LocalTensor<float> sumFloatBufLocal = sumFloatBuf_.Get<float>();

    LocalTensor<ExpandIdxType> indexCountsLocal = indexCountsBuf_.Get<ExpandIdxType>();
    // Bulk-load the BS*K routing tables (expand indices, expert ids, scales).
    const DataCopyExtParams bskParams = {1U, static_cast<uint32_t>(bsKNum_ * sizeof(uint32_t)), 0U, 0U, 0U};
    const DataCopyPadExtParams<ExpandIdxType> copyPadParams{false, 0U, 0U, 0U};
    const DataCopyPadExtParams<float> copyPadFloatParams{false, 0U, 0U, 0U};

    DataCopyPad(indexCountsLocal, expandIdxGM_, bskParams, copyPadParams);
    DataCopyPad(expertIdsLocal, expertIdsGM_, bskParams, copyPadParams);
    DataCopyPad(expandScalesLocal, expandScalesGM_, bskParams, copyPadFloatParams);
    SyncFunc<AscendC::HardEvent::MTE2_S>();

    for (uint32_t tokenIndex = beginIndex; tokenIndex < endIndex; tokenIndex++) {
        uint32_t index = tokenIndex * axisK_;
        SyncFunc<AscendC::HardEvent::MTE3_V>();
        Duplicate(sumFloatBufLocal, (float)0, axisH_);
        // Accumulate the K routed expert outputs for this token.
        for (uint32_t i = 0; i < axisK_; i++) {
            int32_t moeExpert = expertIdsLocal.GetValue(index);
            if (moeExpert < 0) {
                // Negative id marks an unused routing slot.
                index++;
                continue;
            }
            float scaleVal = expandScalesLocal.GetValue(index);
            // Window layout: [shared-expert region][per-expert regions],
            // indexed by expert id and this token's expand position.
            GM_ADDR wAddr = (__gm__ uint8_t *)(epWindowGM_) +
                            expertPerSizeOnWin_ * moeExpertPerRankNum_ * sharedExpertRankNum_ +
                            expertPerSizeOnWin_ * moeExpert + indexCountsLocal.GetValue(index) * axisHExpandXTypeSize_ +
                            tokenOffset * sizeof(ExpandXType);
            rowTmpGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr);
            // NOTE(review): `val` is unused — presumably a scalar GM read to
            // force the remote write to be visible before DataCopy; confirm
            // before removing.
            ExpandXType val = rowTmpGlobal_.GetValue(0);
            LocalTensor<ExpandXType> tmpUb = moeSumQueue_.AllocTensor<ExpandXType>();
            DataCopy(tmpUb, rowTmpGlobal_, processLen);
            moeSumQueue_.EnQue(tmpUb);
            tmpUb = moeSumQueue_.DeQue<ExpandXType>();
            Cast(rowTmpFloatLocal, tmpUb, AscendC::RoundMode::CAST_NONE, processLen);
            AscendC::PipeBarrier<PIPE_V>();
            // Weighted accumulation in float: sum += row * scale.
            AscendC::Muls(mulBufLocal, rowTmpFloatLocal, scaleVal, processLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, mulBufLocal, processLen);
            index++;
            moeSumQueue_.FreeTensor<ExpandXType>(tmpUb);
        }
        LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
        if (sharedExpertRankNum_ > 0U) {
            // Locate which shared-expert rank served this token and its
            // position within that rank's window region.
            uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
            uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
            uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
                              epRankId_ * axisBS_ / sharedExpertRankNum_;
            __gm__ ExpandXType *shareAddr =
                (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
                (tokenIndex - preCnt) * axisH_ + tokenOffset;
            GlobalTensor<ExpandXType> shareTokGlobal;
            shareTokGlobal.SetGlobalBuffer((__gm__ ExpandXType *)(shareAddr));
            SyncFunc<AscendC::HardEvent::V_MTE2>();
            DataCopy(rowTmpLocal, shareTokGlobal, processLen);
            SyncFunc<AscendC::HardEvent::MTE2_V>();
            Cast(rowTmpFloatLocal, rowTmpLocal, AscendC::RoundMode::CAST_NONE, processLen);
            AscendC::PipeBarrier<PIPE_V>();
            // Shared-expert output is added unscaled.
            AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, rowTmpFloatLocal, processLen);
        }

        AscendC::PipeBarrier<PIPE_V>();
        // Round-to-nearest-even back to the payload dtype and write out.
        LocalTensor<ExpandXType> sumBufLocal = tokenBuf_.Get<ExpandXType>();
        Cast(sumBufLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, processLen);
        SyncFunc<AscendC::HardEvent::V_MTE3>();
        DataCopy(expandOutGlobal_[tokenIndex * axisH_ + tokenOffset], sumBufLocal, processLen);
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Full combine pipeline on one core: optional TP reduce-scatter, dispatch
// hand-off (unless deep-fused), status signalling, wait, then the local
// weighted-sum/copy-out phase.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Process()
{
    SyncAll<true>();
    if constexpr (IsNeedReduceScatter) {
        tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
        ReduceScatterTrans();
    }
    if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
        // Non-fused path handles its own buffer setup and TP dispatch.
        BuffInit();
        SetWaitTpStatusAndDisPatch();
    }
    AlltoAllBuffInit();
    SetStatus();
    WaitDispatch();
    LocalWindowCopy();
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Send half of the split pipeline: optional TP reduce-scatter, dispatch,
// then signal readiness to peers. Synchronization is done via cross-core
// events in deep-fuse mode, SyncAll otherwise.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AllToAllSend()
{
    if constexpr (IsNeedReduceScatter) {
        tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
        ReduceScatterTrans();
    }
    BuffInit();
    if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
        SetWaitTpStatusAndDisPatch();
        AlltoAllBuffInit();
    }
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
        AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
    } else {
        SyncAll<true>();
    }
    SetStatus();
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        // Wait for the receive half (ReducePermute) to finish.
        AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
    } else {
        SyncAll<true>();
    }
}
|
||||
|
||||
template <TemplateMC2TypeClass>
// Receive half of the split pipeline: wait for peers' data, combine it
// locally, and (in deep-fuse mode) handshake with AllToAllSend via
// cross-core events.
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ReducePermute()
{
    AlltoAllBuffInit();
    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
    } else {
        SyncAll<true>();
    }

    WaitDispatch();
    LocalWindowCopy();

    if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
        AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
    }
}
|
||||
} // namespace MoeDistributeCombineImpl
|
||||
|
||||
#endif // CAM_MOE_DISTRIBUTE_COMBINE_IMPL_H
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,18 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H
#define DISPATCH_GMM_COMBINE_DECODE_BASE_H

#include "moe_distribute_base.h"

// Template-parameter list shared by the kernel's member-function templates:
// payload element type, index element type, whether a TP reduce-scatter is
// required, and the execution-mode bit flags (see EXEC_FLAG_DEEP_FUSE).
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
// Matching argument list used when naming an instantiation of the template.
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG

#endif  // DISPATCH_GMM_COMBINE_DECODE_BASE_H
|
||||
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef DISPATCH_GMM_COMBINE_DECODE_TILING_H
|
||||
#define DISPATCH_GMM_COMBINE_DECODE_TILING_H
|
||||
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
|
||||
// Scalar problem description passed from the host tiling function to the
// DispatchGmmCombineDecode kernel.
struct DispatchGmmCombineDecodeInfo {
    uint32_t epRankSize;          // epRankSize
    uint32_t epRankId;            // epRankId
    uint32_t moeExpertNum;        // moe expert number
    uint32_t moeExpertNumPerRank; // moe expert number per rank
    uint32_t sharedExpertNum;     // shared expert number
    uint32_t sharedExpertRankNum; // shared expert rank number
    uint32_t quantMode;           // quant mode
    uint32_t globalBs;            // globalBs = BS * worldSize
    uint32_t bs;                  // bs
    uint32_t k;                   // k
    uint32_t h;                   // h
    uint32_t aicNum;              // aicNum
    uint32_t aivNum;              // aivNum
    uint64_t totalUbSize;         // unified-buffer capacity available to the kernel
    uint64_t totalWinSize;        // shared HCCL window size
    uint64_t gmm1HLen;            // hidden length used by the first grouped matmul
};

// Complete tiling blob: MC2 init/communication tilings plus the op info above.
struct DispatchGmmCombineDecodeTilingData {
    Mc2InitTiling mc2InitTiling;
    Mc2CcTiling mc2CcTiling;
    DispatchGmmCombineDecodeInfo disGmmDeqSwigluQuantGmmDeqComInfo;
};

// Global-memory alignment and matmul pipeline staging depths.
constexpr uint32_t GM_ALIGN_BYTE = 512;
constexpr uint32_t CUSTOM_PRELOAD_STAGES = 1;
constexpr uint32_t CUSTOM_L1_STAGES = 2;
constexpr uint32_t CUSTOM_L0A_STAGES = 2;
constexpr uint32_t CUSTOM_L0B_STAGES = 2;
constexpr uint32_t CUSTOM_L0C_STAGES = 1;
constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true;
constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true;

// Tile shapes for the first grouped matmul (GMM1).
constexpr uint32_t GMM1_L1M = 256;
constexpr uint32_t GMM1_L1N = 128;
constexpr uint32_t GMM1_L1K = 512;
constexpr uint32_t GMM1_L0K = 128;
constexpr uint32_t GMM1_EPIM = 64;
constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3;
constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0;

// Staging and tile shapes for the second grouped matmul (GMM2).
constexpr uint32_t GMM2_L1A_STAGES = 4;
constexpr uint32_t GMM2_L1B_STAGES = 2;
constexpr uint32_t GMM2_L0A_STAGES = 4;
constexpr uint32_t GMM2_L0B_STAGES = 2;
constexpr uint32_t GMM2_L1M = 128;
constexpr uint32_t GMM2_L1N = 256;
constexpr uint32_t GMM2_L1K = 512;
constexpr uint32_t GMM2_L0K = 128;
constexpr uint32_t GMM2_EPIM = 32;
constexpr uint32_t GMM2_SWIZZLE_OFFSET = 3;
constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;

constexpr uint32_t WORKSPACE_STAGES = 4;

// Execution-mode bit: run dispatch/combine in the deep-fused (cross-core
// event driven) pipeline.
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
|
||||
@@ -595,6 +595,64 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
|
||||
// Eager-mode binding for the fused dispatch + grouped-matmul + combine decode
// op. Allocates the output token tensor and the per-expert receive-count
// tensor, then forwards everything to the aclnn kernel via EXEC_NPU_CMD.
// Returns (output [bs, h], ep_recv_count [num_local_experts * ep_rank_size]).
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
    const at::Tensor &x,
    const at::Tensor &expert_ids,
    const at::Tensor &gmm1_permuted_weight,
    const at::Tensor &gmm1_permuted_weight_scale,
    const at::Tensor &gmm2_weight,
    const at::Tensor &gmm2_weight_scale,
    const c10::optional<at::Tensor> &expert_smooth_scales,
    const c10::optional<at::Tensor> &expert_scales,
    c10::string_view group_ep,
    int64_t ep_rank_size,
    int64_t ep_rank_id,
    int64_t moe_expert_num,
    int64_t shared_expert_num,
    int64_t shared_expert_rank_num,
    int64_t quant_mode,
    int64_t global_bs)
{
    auto x_shape = x.sizes();
    int bs = x_shape[0];
    int h = x_shape[1];

    at::Tensor output = at::empty({bs, h}, x.options());

    // Shared-expert ranks host exactly one local expert; MoE ranks split
    // moe_expert_num across the remaining ranks.
    // NOTE(review): divides by (ep_rank_size - shared_expert_rank_num)
    // without a guard — confirm callers never pass equal values.
    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
    at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());

    // EXEC_NPU_CMD needs a NUL-terminated mutable char*; copy the view into
    // a local buffer.
    vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
    group_ep_chrs.push_back('\0');
    char *group_ep_ptr = &group_ep_chrs[0];
    EXEC_NPU_CMD(
        // op api
        aclnnDispatchGmmCombineDecode,
        // input tensors
        x,
        expert_ids,
        gmm1_permuted_weight,
        gmm1_permuted_weight_scale,
        gmm2_weight,
        gmm2_weight_scale,
        expert_smooth_scales,
        expert_scales,
        // input attrs
        group_ep_ptr,
        ep_rank_size,
        ep_rank_id,
        moe_expert_num,
        shared_expert_num,
        shared_expert_rank_num,
        quant_mode,
        global_bs,
        // output tensors
        output,
        ep_recv_count);
    return {output, ep_recv_count};
}
|
||||
|
||||
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
|
||||
c10::optional<c10::string_view> format_mode,
|
||||
c10::optional<c10::string_view> quant_mode)
|
||||
@@ -818,6 +876,19 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
|
||||
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
|
||||
|
||||
ops.def(
|
||||
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
|
||||
" Tensor gmm1_permuted_weight_scale,"
|
||||
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
|
||||
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
|
||||
" str group_ep='',"
|
||||
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
|
||||
" int shared_expert_num=1, int shared_expert_rank_num=0,"
|
||||
" int quant_mode=0,"
|
||||
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
|
||||
);
|
||||
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);
|
||||
|
||||
ops.def(
|
||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
|
||||
" Tensor group_list, *,"
|
||||
|
||||
@@ -154,6 +154,37 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
|
||||
// Meta (shape-only) implementation of dispatch_gmm_combine_decode: computes
// the output shapes on the Meta device without launching the kernel, so the
// op can be traced/compiled. Must stay in sync with the eager binding's
// allocation logic.
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
    const at::Tensor &x,
    const at::Tensor &expert_ids,
    const at::Tensor &gmm1_permuted_weight,
    const at::Tensor &gmm1_permuted_weight_scale,
    const at::Tensor &gmm2_weight,
    const at::Tensor &gmm2_weight_scale,
    const c10::optional<at::Tensor> &expert_smooth_scales,
    const c10::optional<at::Tensor> &expert_scales,
    c10::string_view group_ep,
    int64_t ep_rank_size,
    int64_t ep_rank_id,
    int64_t moe_expert_num,
    int64_t shared_expert_num,
    int64_t shared_expert_rank_num,
    int64_t quant_mode,
    int64_t global_bs)
{
    auto x_shape = x.sizes();
    int bs = x_shape[0];
    int h = x_shape[1];

    at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta));

    // Same local-expert-count rule as the eager path; NOTE(review): shares
    // its unguarded division by (ep_rank_size - shared_expert_rank_num).
    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
    at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta));

    return {output, ep_recv_count};
}
|
||||
|
||||
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
|
||||
c10::optional<c10::string_view> format_mode,
|
||||
c10::optional<c10::string_view> quant_mode)
|
||||
@@ -255,6 +286,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
|
||||
// Grouped matmul swiglu quant weight nz tensor list
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
|
||||
// dispatch_gmm_combine_decode meta implementation
|
||||
ops.impl("dispatch_gmm_combine_decode", &vllm_ascend::meta::dispatch_gmm_combine_decode_meta);
|
||||
// batch_matmul_transpose
|
||||
ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
|
||||
// Lightning indexer
|
||||
|
||||
@@ -163,6 +163,7 @@ cd ..
|
||||
```
|
||||
|
||||
vllm-ascend will build custom operators by default. If you don't want to build it, set `COMPILE_CUSTOM_KERNELS=0` environment to disable it.
|
||||
If you are building custom operators for Atlas A3, you should run `git submodule update --init --recursive` manually, or ensure your environment has Internet access.
|
||||
:::
|
||||
|
||||
```{note}
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch_npu
|
||||
import torchair
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
# torchair graph-compile configuration used by the test's compiled paths.
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
# Allow NPU-private (e.g. FRACTAL_NZ) tensor formats to flow through ops.
torch_npu.npu.config.allow_internal_format = True
# Register vllm-ascend's custom ops (torch.ops._C_ascend.*).
enable_custom_op()
# Directory that per-rank log files are written under (see redirect_output).
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
|
||||
|
||||
|
||||
def redirect_output(log_file_path):
    """Redirect this process's stdout/stderr into a per-rank log file.

    Creates ``LOG_NAME/<log_file_path>`` (and any parent directories), then
    dup2()s its descriptor over both stdout and stderr so all subsequent
    output of the process lands in that file.

    Returns:
        The open file object. The caller must keep a reference to it for as
        long as the redirection should stay alive.
    """
    log_path = Path(LOG_NAME) / log_file_path
    log_path.parent.mkdir(parents=True, exist_ok=True)
    # Open the path computed above (the original rebuilt it via string
    # concatenation, which could diverge from log_path).
    f = open(log_path, "w")
    os.dup2(f.fileno(), sys.stdout.fileno())
    os.dup2(f.fileno(), sys.stderr.fileno())
    return f
|
||||
|
||||
|
||||
def permute_weight(w: torch.Tensor, tile_n):
    """Reorder the last axis of *w* into the kernel's interleaved layout.

    The final dimension (length n) is viewed as (2, n // tile_n,
    tile_n // 2); the first two of those axes are swapped and the result is
    flattened back to length n. Returns a contiguous tensor.
    """
    *lead, last = w.shape
    halves = w.reshape(*lead, 2, last // tile_n, tile_n // 2)
    swapped = halves.transpose(-3, -2)
    return swapped.reshape(*lead, last).contiguous()
|
||||
|
||||
|
||||
def from_inclusive_prefix_sum(pref):
    """Invert an inclusive prefix sum back into per-element counts.

    Accepts either a torch.Tensor or a plain Python sequence. Empty inputs
    are returned as-is (tensor) or as an empty list.
    """
    if isinstance(pref, torch.Tensor):
        if pref.numel() == 0:
            return pref
        # diff with a zero prepended reproduces [p0, p1-p0, p2-p1, ...].
        return torch.diff(pref, prepend=torch.zeros_like(pref[:1]))

    if not pref:
        return []
    return [pref[0]] + [cur - prev for prev, cur in zip(pref, pref[1:])]
|
||||
|
||||
|
||||
def output_to_file(rank_id):
    # Switch controlling whether a rank's stdout/stderr gets redirected to a
    # per-rank log file (see redirect_output). Currently disabled for every
    # rank; the rank_id argument is accepted but unused.
    return False
|
||||
|
||||
|
||||
class DecodeMoeOps(torch.nn.Module):
    """Base module for the decode MoE paths under test.

    Holds the EP/communication configuration and the (post-processed) GMM
    weights; subclasses implement weight layout conversion
    (_process_weights_after_loading) and the actual op sequence (_apply_ops).
    """

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__()
        self.ep_hcomm_info = ep_hcomm_info
        self.batch_size = batch_size
        self.token_hidden_size = token_hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.ep_world_size = ep_world_size
        self.moe_expert_num = moe_expert_num
        self.global_rank_id = global_rank_id
        self.shared_expert_rank_num = shared_expert_rank_num
        # Ranks below shared_expert_rank_num host a single shared expert;
        # the remaining ranks split the MoE experts evenly.
        is_shared_expert = global_rank_id < shared_expert_rank_num
        moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                     shared_expert_rank_num)
        self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
        self.ep_recv_count_size = self.local_expert_num * ep_world_size
        # Placeholder weight attributes; overwritten by the subclass in
        # _process_weights_after_loading below.
        self.gmm1_weight = torch.empty([
            self.local_expert_num, self.token_hidden_size,
            self.moe_intermediate_size * 2
        ])
        self.gmm1_weight_scale = torch.empty(
            [self.local_expert_num, self.moe_intermediate_size * 2])
        self.gmm2_weight = torch.empty([
            self.local_expert_num, self.moe_intermediate_size,
            self.token_hidden_size
        ])
        self.gmm2_weight_scale = torch.empty(
            [self.local_expert_num, self.token_hidden_size])
        self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
                                            gmm2_weight, gmm2_weight_scale)

    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                       gmm2_weight, gmm2_weight_scale):
        # Subclass hook: convert the raw weights into the layout the
        # subclass's op sequence expects and store them as Parameters.
        raise NotImplementedError("To be implemented in subclass")

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
        # Subclass hook: run the dispatch/GMM/combine pipeline.
        raise NotImplementedError("To be implemented in subclass")

    def forward(self, x, expert_ids, smooth_scales, expert_scales):
        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
|
||||
|
||||
|
||||
class SmallOps(DecodeMoeOps):
    """Reference path: the same pipeline built from individual torch_npu ops
    (dispatch -> grouped matmul -> dequant-swiglu-quant -> grouped matmul ->
    combine). Used as the golden baseline for the fused custom op.
    """

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                         gmm2_weight_scale, ep_hcomm_info, batch_size,
                         token_hidden_size, moe_intermediate_size,
                         ep_world_size, moe_expert_num, global_rank_id,
                         shared_expert_rank_num)
        # No tensor parallelism in this test: empty TP comm group.
        self.tp_hcomm_info = ""

    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                       gmm2_weight, gmm2_weight_scale):
        # The reference grouped matmuls take NZ-format weights directly.
        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                                torch_npu.Format.FRACTAL_NZ)
        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
                                                torch_npu.Format.FRACTAL_NZ)
        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
                                                    requires_grad=False)
        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
                                                    requires_grad=False)

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
        outputs = torch_npu.npu_moe_distribute_dispatch_v2(
            x=x,
            expert_ids=expert_ids,
            expert_scales=expert_scales,
            group_ep=self.ep_hcomm_info,
            ep_world_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            group_tp=self.tp_hcomm_info,
            tp_world_size=1,
            tp_rank_id=0,
            expert_shard_type=0,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            quant_mode=2,
            global_bs=self.batch_size * self.ep_world_size,
            expert_token_nums_type=1,  # 0: prefix-sum form, 1: per-expert counts
        )
        expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
        output_dtype = x.dtype

        # GMM1: int8 x int8 -> int32, dequantized in the next step.
        y1_int32 = torch_npu.npu_grouped_matmul(
            x=[expand_x],
            weight=[self.gmm1_weight],
            split_item=3,
            group_list_type=1,  # default 0 would mean prefix-sum form
            group_type=0,  # group along the m axis
            group_list=expert_token_nums,
            output_dtype=torch.int32)[0]
        # Fused dequant + swiglu + per-token requant for GMM2's input.
        y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
            x=y1_int32,
            weight_scale=self.gmm1_weight_scale.to(torch.float32),
            activation_scale=dynamic_scales,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=expert_token_nums,
            activate_left=True,
            quant_mode=1,
        )
        y2 = torch_npu.npu_grouped_matmul(x=[y1],
                                          weight=[self.gmm2_weight],
                                          scale=[self.gmm2_weight_scale],
                                          per_token_scale=[y1_scale],
                                          split_item=2,
                                          group_list_type=1,
                                          group_type=0,
                                          group_list=expert_token_nums,
                                          output_dtype=output_dtype)[0]
        combine_output = torch_npu.npu_moe_distribute_combine_v2(
            expand_x=y2,
            expert_ids=expert_ids,
            assist_info_for_combine=assist_info_for_combine,
            ep_send_counts=ep_send_counts,
            expert_scales=expert_scales,
            group_ep=self.ep_hcomm_info,
            ep_world_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            tp_send_counts=tp_send_counts,
            expand_scales=expand_scales,
            group_tp=self.tp_hcomm_info,
            tp_world_size=1,
            tp_rank_id=0,
            expert_shard_type=0,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            global_bs=self.batch_size * self.ep_world_size)
        # Truncate the send counts to match the fused op's recv-count shape.
        return (combine_output, ep_send_counts[:self.ep_recv_count_size])
|
||||
|
||||
|
||||
class FusionOp(DecodeMoeOps):
    """Path under test: the single fused custom op
    torch.ops._C_ascend.dispatch_gmm_combine_decode, fed weights permuted
    into the layout the fused kernel expects.
    """

    def __init__(self,
                 gmm1_weight,
                 gmm1_weight_scale,
                 gmm2_weight,
                 gmm2_weight_scale,
                 ep_hcomm_info,
                 batch_size,
                 token_hidden_size,
                 moe_intermediate_size,
                 ep_world_size,
                 moe_expert_num,
                 global_rank_id,
                 shared_expert_rank_num=0):
        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                         gmm2_weight_scale, ep_hcomm_info, batch_size,
                         token_hidden_size, moe_intermediate_size,
                         ep_world_size, moe_expert_num, global_rank_id,
                         shared_expert_rank_num)

    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                       gmm2_weight, gmm2_weight_scale):
        # Interleave GMM1's gate/up halves in 64-row groups, then transpose
        # to (expert, hidden, 2*intermediate) for the fused kernel.
        gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
            .view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
            .transpose(1,2).contiguous()\
            .view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
            .transpose(1,2).contiguous()
        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                                torch_npu.Format.ND)
        # In-place no-op add; presumably forces materialization of the ND
        # tensor before the NZ cast — confirm before removing.
        gmm1_weight.add_(0)
        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                                torch_npu.Format.FRACTAL_NZ)
        # Scales follow the same 128-wide interleaving as the weight columns.
        gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
        gmm2_weight = torch_npu.npu_format_cast(
            gmm2_weight.transpose(1, 2).contiguous(),
            torch_npu.Format.FRACTAL_NZ)

        gmm1_weight_scale = gmm1_weight_scale.float()
        gmm2_weight_scale = gmm2_weight_scale.float()

        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
                                                    requires_grad=False)
        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
                                                    requires_grad=False)

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
        # Single fused call; returns (output, ep_recv_count), matching the
        # SmallOps reference path's return shape.
        output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
            x=x,
            expert_ids=expert_ids,
            gmm1_permuted_weight=self.gmm1_weight,
            gmm1_permuted_weight_scale=self.gmm1_weight_scale,
            gmm2_weight=self.gmm2_weight,
            gmm2_weight_scale=self.gmm2_weight_scale,
            expert_smooth_scales=smooth_scales,
            expert_scales=expert_scales,
            group_ep=self.ep_hcomm_info,
            ep_rank_size=self.ep_world_size,
            ep_rank_id=self.global_rank_id,
            moe_expert_num=self.moe_expert_num,
            shared_expert_num=1,
            shared_expert_rank_num=self.shared_expert_rank_num,
            quant_mode=0,
            global_bs=self.batch_size * self.ep_world_size)
        return output
|
||||
|
||||
|
||||
def generate_datas(batch_size,
|
||||
token_hidden_size,
|
||||
moe_intermediate_size,
|
||||
ep_world_size,
|
||||
moe_expert_num,
|
||||
global_rank_id,
|
||||
shared_expert_rank_num=0,
|
||||
top_k=8,
|
||||
test_bfloat16=True,
|
||||
enable_dynamic_bs=False):
|
||||
is_shared_expert = global_rank_id < shared_expert_rank_num
|
||||
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
|
||||
shared_expert_rank_num)
|
||||
actual_bs = int(
|
||||
torch.randint(1, batch_size, [1]).item(
|
||||
) if enable_dynamic_bs else batch_size)
|
||||
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
|
||||
gmm1_input_dim = token_hidden_size
|
||||
gmm1_output_dim = moe_intermediate_size * 2
|
||||
gmm2_input_dim = moe_intermediate_size
|
||||
gmm2_output_dim = token_hidden_size
|
||||
x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
|
||||
expert_ids = torch.arange(
|
||||
global_rank_id * batch_size * top_k,
|
||||
global_rank_id * batch_size * top_k + actual_bs * top_k).to(
|
||||
torch.int32).view(actual_bs, top_k)
|
||||
expert_ids = expert_ids % moe_expert_num
|
||||
if is_shared_expert:
|
||||
gmm1_weight = torch.ones([
|
||||
local_expert_num, gmm1_input_dim, gmm1_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm2_weight = torch.ones([
|
||||
local_expert_num, gmm2_input_dim, gmm2_output_dim
|
||||
]).to(torch.int8) * 4
|
||||
gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
|
||||
gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
|
||||
gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
|
||||
]) * 0.0015
|
||||
gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
|
||||
]) * 0.0015
|
||||
else:
|
||||
gmm1_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
|
||||
gmm2_weight = torch.randint(
|
||||
-16, 16,
|
||||
[local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
|
||||
gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
|
||||
]) * 0.003 + 0.0015
|
||||
expert_scales = torch.rand(actual_bs, top_k)
|
||||
if test_bfloat16:
|
||||
x = x.bfloat16()
|
||||
gmm1_weight_scale = gmm1_weight_scale.bfloat16()
|
||||
gmm2_weight_scale = gmm2_weight_scale.bfloat16()
|
||||
else:
|
||||
x = x.half()
|
||||
smooth_sales = None
|
||||
return (x, expert_ids, smooth_sales, expert_scales), \
|
||||
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
|
||||
actual_bs
|
||||
|
||||
|
||||
def run_once(local_rank_id,
             batch_size,
             token_hidden_size,
             moe_intermediate_size,
             ep_world_size,
             moe_expert_num,
             shared_expert_rank_num=0,
             top_k=8,
             test_bfloat16=True,
             enable_dynamic_bs=False,
             test_graph=False):
    """Per-rank worker: compare the fused custom op against the small-ops
    reference on identical random data.

    Spawned once per EP rank via ``mp.spawn`` (which supplies
    ``local_rank_id``). Initializes an HCCL process group, builds two
    communicators (one per module), runs ``SmallOps`` and ``FusionOp`` on
    the same inputs, and asserts their outputs agree.
    """
    # Optionally tee this rank's output to a per-rank log file.
    # NOTE(review): redirect_output / output_to_file are defined elsewhere
    # in this file.
    log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                               ) if output_to_file(local_rank_id) else None
    global_rank_id = local_rank_id  # single node: local rank == global rank
    device_id = local_rank_id % 16  # at most 16 NPUs per node
    torch_npu.npu.set_device(device_id)

    # Initialize the distributed environment (single-node rendezvous).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    dist.init_process_group(backend="hccl",
                            rank=local_rank_id,
                            world_size=ep_world_size)
    ep_ranks_list = list(np.arange(0, ep_world_size))
    # Two separate groups so the fused op and the reference path use
    # independent HCCL communicators over the same ranks.
    ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
    ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list)

    # Extract the HCCL communicator names the custom ops need.
    ep_hcomm_info_fused = ep_group._get_backend(
        torch.device("npu")).get_hccl_comm_name(local_rank_id)
    ep_hcomm_info_small = ep_group_small._get_backend(
        torch.device("npu")).get_hccl_comm_name(local_rank_id)
    torch_npu.npu.synchronize(device_id)

    parameter = (batch_size, token_hidden_size, moe_intermediate_size,
                 ep_world_size, moe_expert_num, global_rank_id,
                 shared_expert_rank_num)
    input_datas, weight_datas, actual_bs = generate_datas(
        *parameter, top_k, test_bfloat16, enable_dynamic_bs)
    # Move generated tensors to the NPU; the smooth-scales slot may be None.
    input_datas = [
        data.npu() if data is not None else None for data in input_datas
    ]
    weight_datas = [
        data.npu() if data is not None else None for data in weight_datas
    ]
    small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
                         *parameter).npu()  # type: ignore
    fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
                         *parameter).npu()  # type: ignore
    if test_graph:
        # Exercise the graph path through the NPU compile backend.
        fused_ops = torch.compile(fused_ops, backend=npu_backend)
    small_op_token_output, small_op_count_output = small_ops(*input_datas)
    fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
    torch_npu.npu.synchronize(device_id)
    dist.destroy_process_group()
    if log_file is not None:
        log_file.close()
    # Normalize the reference counts before comparing (presumably the
    # reference path reports an inclusive prefix sum — see
    # from_inclusive_prefix_sum, defined elsewhere in this file).
    small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
    # Loose tolerance for token values (int8-quantized matmuls); counts
    # must match exactly.
    torch.testing.assert_close(small_op_token_output.cpu(),
                               fused_op_token_output.cpu(),
                               atol=2.0,
                               rtol=0.02)
    torch.testing.assert_close(small_op_count_output.cpu(),
                               fused_op_count_output.cpu())
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@torch.inference_mode()
def test():
    """Launch one worker process per EP rank and run the fused-vs-reference
    comparison (see ``run_once``) across the whole group."""
    ep_world_size = 16
    # Positional arguments for run_once, after the spawner-supplied rank.
    worker_args = (
        64,     # batch_size
        7168,   # token_hidden_size
        2048,   # moe_intermediate_size
        ep_world_size,
        64,     # moe_expert_num
        0,      # shared_expert_rank_num
        8,      # top_k
        True,   # test_bfloat16
        False,  # enable_dynamic_bs
        False,  # test_graph
    )
    mp.spawn(run_once, args=worker_args, nprocs=ep_world_size, join=True)
|
||||
Reference in New Issue
Block a user