[Kernel] add custom op DispatchGmmCombineDecode (#4139)

#### What this PR does / why we need it?
Add the custom opapi DispatchGmmCombineDecode for A3, including the kernel implementation, the Python API, and pytest-based tests.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangqiankun <wangqiankun13@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: GuoRen868
Date: 2025-12-06 17:33:14 +08:00
Committed by: GitHub
Parent: cb42564942
Commit: 4bd1030842
29 changed files with 7851 additions and 27 deletions

View File

@@ -138,3 +138,21 @@ jobs:
      image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
      tests: ${{ matrix.test_config.tests }}
      name: ${{ matrix.test_config.name }}
  custom-ops-tests:
    name: test ops
    if: always() && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch')
    needs: multi-node-tests
    strategy:
      fail-fast: false
      matrix:
        test_config:
          - name: custom-op-dispatch_gmm_combine_decode
            os: linux-aarch64-a3-16
            tests: tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
    uses: ./.github/workflows/_e2e_nightly_single_node.yaml
    with:
      runner: ${{ matrix.test_config.os }}
      vllm: v0.12.0
      image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
      tests: ${{ matrix.test_config.tests }}
      name: ${{ matrix.test_config.name }}

View File

@@ -3,8 +3,6 @@
ROOT_DIR=$1
SOC_VERSION=$2
git config --global --add safe.directory "$ROOT_DIR"
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
# ASCEND310P series
# currently, no custom aclnn ops for ASCEND310 series
@@ -13,11 +11,41 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
exit 0
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
# ASCEND910B (A2) series
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
# dependency: catlass
git config --global --add safe.directory "$ROOT_DIR"
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass/include
if [[ ! -d "${CATLASS_PATH}" ]]; then
echo "depdendency catlass is missing, try to fetch it..."
if ! git submodule update --init --recursive; then
echo "fetch failed"
exit 1
fi
fi
# dependency: cann-toolkit file moe_distribute_base.h
HCCL_STRUCT_FILE_PATH=$(find -L "${ASCEND_TOOLKIT_HOME}" -name "moe_distribute_base.h" 2>/dev/null | head -n1)
if [ -z "$HCCL_STRUCT_FILE_PATH" ]; then
echo "cannot find moe_distribute_base.h file in CANN env"
exit 1
fi
# for dispatch_gmm_combine_decode
yes | cp "${HCCL_STRUCT_FILE_PATH}" "${ROOT_DIR}/csrc/dispatch_gmm_combine_decode/op_kernel"
# for dispatch_ffn_combine
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
TARGET_FILE="$TARGET_DIR/$(basename "$HCCL_STRUCT_FILE_PATH")"
echo "*************************************"
echo "$HCCL_STRUCT_FILE_PATH"
echo "$TARGET_DIR"
cp "$HCCL_STRUCT_FILE_PATH" "$TARGET_DIR"
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine;dispatch_gmm_combine_decode;"
SOC_ARG="ascend910_93"
else
# others
@@ -25,29 +53,6 @@ else
exit 0
fi
git submodule init
git submodule update
# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h
file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1)
if [ -z "$file_path" ]; then
echo "cannot find moe_distribute_base.h file in CANN env"
exit 1
fi
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
TARGET_FILE="$TARGET_DIR/$(basename "$file_path")"
echo "*************************************"
echo $file_path
echo "$TARGET_DIR"
cp "$file_path" "$TARGET_DIR"
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
# build custom ops
cd csrc

View File

@@ -0,0 +1,59 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
set(_DISPATCH_GMM_INC_OPTS)
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
list(APPEND _DISPATCH_GMM_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
else()
message(FATAL_ERROR "dependency catlass is missing, you can fetch it by running 'git submodule update --init --recursive'")
endif()
add_ops_compile_options(
OP_NAME DispatchGmmCombineDecode
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
${_DISPATCH_GMM_INC_OPTS}
)
target_sources(op_host_aclnnInner PRIVATE
dispatch_gmm_combine_decode_def.cpp
)
target_sources(opapi PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
target_sources(aclnn_ops_infer PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
endif ()
target_sources(optiling PRIVATE
dispatch_gmm_combine_decode_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
dispatch_gmm_combine_decode_proto.cpp
)
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")
install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,101 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <string.h>
#include "graph/types.h"
#include "aclnn/opdev/platform.h"
#include "aclnn_dispatch_gmm_combine_decode.h"
enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
NNOPBASE_HCCL_SERVER_TYPE_MTE,
NNOPBASE_HCCL_SERVER_TYPE_END
};
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
#ifdef __cplusplus
extern "C" {
#endif
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
}
aclnnStatus aclnnDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
{
if (NnopbaseSetHcclServerType) {
if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) {
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU);
} else {
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
}
}
return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif
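
A minimal host-side sketch of the usual two-phase aclnn calling convention for this op. Everything besides the two `aclnnDispatchGmmCombineDecode*` entry points declared above is an assumption: tensor handles, the EP communication group, and the stream are presumed to be created elsewhere, and `shareExpertNum` is pinned to 1 because the tiling later in this diff rejects other values.

```cpp
#include <cstdint>
#include "acl/acl.h"
#include "aclnn_dispatch_gmm_combine_decode.h"

// Sketch only: all aclTensor handles, the HCCL EP group named by `groupEp`,
// and `stream` are assumed to exist already (e.g. via aclCreateTensor).
aclnnStatus RunDispatchGmmCombineDecode(
    const aclTensor *x, const aclTensor *expertIds,
    const aclTensor *gmm1PermutedWeight, const aclTensor *gmm1PermutedWeightScale,
    const aclTensor *gmm2Weight, const aclTensor *gmm2WeightScale,
    const aclTensor *expertScales, char *groupEp,
    int64_t epRankSize, int64_t epRankId, int64_t moeExpertNum,
    int64_t shareExpertRankNum, int64_t quantMode, int64_t globalBs,
    const aclTensor *output, const aclTensor *epRecvCount, aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    // Phase 1: query the workspace size and build the executor.
    aclnnStatus ret = aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
        x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
        gmm2Weight, gmm2WeightScale,
        /*expertSmoothScalesOptional=*/nullptr, expertScales, groupEp,
        epRankSize, epRankId, moeExpertNum,
        /*shareExpertNum=*/1, // the tiling only accepts exactly 1 shared expert
        shareExpertRankNum, quantMode, globalBs,
        output, epRecvCount, &workspaceSize, &executor);
    if (ret != ACL_SUCCESS) {
        return ret;
    }
    // Phase 2: allocate the device workspace and launch on the stream.
    void *workspace = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        if (ret != ACL_SUCCESS) {
            return ret;
        }
    }
    ret = aclnnDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
    (void)aclrtSynchronizeStream(stream);
    if (workspace != nullptr) {
        (void)aclrtFree(workspace);
    }
    return ret;
}
```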

View File

@@ -0,0 +1,51 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE
#define DISPATCH_GMM_COMBINE_DECODE
#include "aclnn/acl_meta.h"
#ifdef __cplusplus
extern "C" {
#endif
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,83 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/op_def_registry.h"
namespace ops {
class DispatchGmmCombineDecode : public OpDef
{
public:
explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();
this->Attr("moe_expert_num").Int();
this->Attr("share_expert_num").Int();
this->Attr("share_expert_rank_num").Int();
this->Attr("quant_mode").Int();
this->Attr("global_bs").Int();
this->MC2().HcclGroup({"group_ep"});
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(DispatchGmmCombineDecode);
} // namespace ops

View File

@@ -0,0 +1,95 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdint>
#include "log/ops_log.h"
#include "error/ops_error.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
namespace ge {
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
static ge::graphStatus InferShape(gert::InferShapeContext *context)
{
const char *nodeName = context->GetNodeName();
// infer output shape
const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
recvCountOutShape == nullptr) {
return GRAPH_FAILED;
}
if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
return GRAPH_FAILED;
}
int bs = expertIdsShape->GetDim(0);
int h = expandXShape->GetDim(1);
expandXOutShape->SetDimNum(expandXShape->GetDimNum());
expandXOutShape->SetDim(0, bs);
expandXOutShape->SetDim(1, h);
// infer recvCount shape
auto attrs = context->GetAttrs();
OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
OPS_ERR_IF(epRankIdPtr == nullptr, OPS_LOG_E(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED);
OPS_ERR_IF(moeExpertNumPtr == nullptr, OPS_LOG_E(nodeName, "moeExpertNumPtr is nullptr."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(epRankSizePtr == nullptr, OPS_LOG_E(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED);
OPS_ERR_IF(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(nodeName, "sharedExpertRankNumPtr is nullptr."),
return ge::GRAPH_FAILED);
uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
recvCountOutShape->SetDimNum(1);
bool isShareExpert = (epRankId < sharedExpertRankNum);
if (isShareExpert) {
recvCountOutShape->SetDim(0, epRankSize);
} else {
recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
}
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX);
context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType);
context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
return ge::GRAPH_SUCCESS;
}
IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType);
} // namespace ge
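
As a worked illustration of the `ep_recv_count` shape rule in `InferShape` above, with made-up numbers (not taken from this PR):

```cpp
#include <cstdint>

// Illustrative: 32 EP ranks, the first 2 host the shared expert, and 60 MoE
// experts spread evenly over the remaining 30 ranks.
constexpr uint32_t epRankSize = 32;
constexpr uint32_t sharedExpertRankNum = 2;
constexpr uint32_t moeExpertNum = 60;
constexpr uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum); // 2

// Shared-expert rank (epRankId < 2): ep_recv_count is [epRankSize] = [32],
// one receive counter per source rank.
// MoE rank (epRankId >= 2): ep_recv_count is
// [epRankSize * moeExpertNumPerRank] = [64], one counter per
// (source rank, local expert) pair.
static_assert(epRankSize * moeExpertNumPerRank == 64, "MoE-rank recv-count length");
```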

View File

@@ -0,0 +1,335 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdio>
#include <cstdint>
#include <string>
#include "log/ops_log.h"
#include "error/ops_error.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/hccl/hccl_tiling.h"
using namespace ge;
namespace {
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr uint32_t GM_ALIGN_SIZE = 512;
constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2;
constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024;
constexpr uint32_t CUBE_WORKSPACE_STAGE = 4;
constexpr uint32_t RESERVED_WORKSPACE_SIZE = 256 * 1024;
constexpr uint32_t INPUT_X_INDEX = 0;
constexpr uint32_t INPUT_EXPERT_IDS_INDEX = 1;
constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2;
constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3;
constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4;
constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXPERT_NUM = 512;
constexpr uint32_t SUPPORT_TOP_K = 12;
constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
} // namespace
namespace optiling {
static size_t CeilUp(size_t x, size_t y)
{
return (x + y - 1) / y * y;
}
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
{
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
{
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE),
return ge::GRAPH_FAILED);
OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE),
return ge::GRAPH_FAILED);
uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
OPS_ERR_IF(
tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
OPS_LOG_E(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
return ge::GRAPH_FAILED);
uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
OPS_ERR_IF(
gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
OPS_LOG_E(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
return ge::GRAPH_FAILED);
uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
OPS_ERR_IF(topK > SUPPORT_TOP_K, OPS_LOG_E(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K),
return ge::GRAPH_FAILED);
uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
if (globalBatchSize == 0) {
globalBatchSize = epRankSize * batchSize;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize;
} else {
OPS_ERR_IF(globalBatchSize < 0, OPS_LOG_E(nodeName, "globalBatchSize must >= 0."), return ge::GRAPH_FAILED);
OPS_ERR_IF(globalBatchSize % epRankSize > 0,
OPS_LOG_E(nodeName, "globalBatchSize must be divisible by epRankSize."),
return ge::GRAPH_FAILED);
}
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t recvAivNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aivNum / 2;
OPS_ERR_IF(
moeExpertNumPerRank > recvAivNum,
OPS_LOG_E(nodeName, "moeExpertNumPerRank must <= (aivNum/2)(%u), but got %u", recvAivNum, moeExpertNumPerRank),
return ge::GRAPH_FAILED);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
{
auto attrs = context->GetAttrs();
OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
auto epRankIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_ID_INDEX);
auto moeExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_MOE_EXPERT_NUM_INDEX);
auto sharedExpertNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_NUM_INDEX);
auto sharedExpertRankNumPtr = attrs->GetAttrPointer<int64_t>(ATTR_SHARE_EXPERT_RANK_NUM_INDEX);
auto quantModePtr = attrs->GetAttrPointer<int64_t>(ATTR_QUANT_MODE_INDEX);
auto globalBsPtr = attrs->GetAttrPointer<int64_t>(ATTR_GLOBAL_BS_INDEX);
uint32_t epRankSize = static_cast<uint32_t>(*epRankSizePtr);
uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
uint32_t moeExpertNum = static_cast<uint32_t>(*moeExpertNumPtr);
uint32_t sharedExpertNum = static_cast<uint32_t>(*sharedExpertNumPtr);
uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum);
OPS_ERR_IF(epRankId < 0, OPS_LOG_E(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED);
OPS_ERR_IF(epRankId >= epRankSize, OPS_LOG_E(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED);
OPS_ERR_IF(moeExpertNum > MAX_MOE_EXPERT_NUM, OPS_LOG_E(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXPERT_NUM),
return ge::GRAPH_FAILED);
OPS_ERR_IF(moeExpertNum <= 0, OPS_LOG_E(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
OPS_ERR_IF(sharedExpertNum != 1, OPS_LOG_E(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED);
OPS_ERR_IF(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0,
OPS_LOG_E(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."),
return ge::GRAPH_FAILED);
groupEp = std::string(groupEpPtr);
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = epRankSize;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = epRankId;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = moeExpertNum;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = sharedExpertNum;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum = sharedExpertRankNum;
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast<uint32_t>(*quantModePtr);
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast<uint32_t>(*globalBsPtr);
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank = moeExpertNumPerRank;
return ge::GRAPH_SUCCESS;
}
static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling, const std::string groupEp)
{
const char *nodeName = context->GetNodeName();
OPS_LOG_D(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str());
uint32_t opType = OP_TYPE_ALL_TO_ALL;
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
std::string algConfigAllGatherStr = "AllGather=level0:ring";
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr);
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
}
static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
{
size_t *workSpaces = context->GetWorkspaceSizes(1);
OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
size_t maxTokenNum;
uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t globalBs = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
uint32_t maxBatchSize = globalBs / epRankSize;
uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
if (epRankId < sharedExpertRankNum) {
maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
} else {
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
}
size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t CVSwapBufferSize =
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
size_t reservedSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize +
epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + reservedSize;
workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context)
{
const char *nodeName = context->GetNodeName();
DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData<DispatchGmmCombineDecodeTilingData>();
OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
std::string groupEp = "";
const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX);
OPS_ERR_IF(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
OPS_ERR_IF(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "x shape dims must be 2, but current dim num is %lu.",
xStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t batchSize = xStorageShape->GetStorageShape().GetDim(0);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs = batchSize;
const int64_t hiddenSize = xStorageShape->GetStorageShape().GetDim(1);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize;
const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX);
OPS_ERR_IF(expertIdsStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIds shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.",
expertIdsStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
context->SetTilingKey(0);
} else {
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
}
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context)
{
ge::graphStatus ret = DispatchGmmCombineDecodeTilingFuncImpl(context);
return ret;
}
struct DispatchGmmCombineDecodeCompileInfo {};
ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context)
{
(void)context;
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(DispatchGmmCombineDecode)
.Tiling(DispatchGmmCombineDecodeTilingFunc)
.TilingParse<DispatchGmmCombineDecodeCompileInfo>(TilingParseForDispatchGmmCombineDecode);
} // namespace optiling
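
To make the workspace bound in `SetWorkSpace` concrete, a small worked example with illustrative values (chosen only to pass the range checks in `CheckData`, not taken from this PR):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative values only.
constexpr uint32_t epRankSize = 32;
constexpr uint32_t sharedExpertRankNum = 2;
constexpr uint32_t globalBs = 32 * 96;                   // 96 tokens per rank
constexpr uint32_t maxBatchSize = globalBs / epRankSize; // 96
constexpr uint32_t topK = 8;
constexpr uint32_t moeExpertNumPerRank = 2;

// Shared-expert rank: every token visits the shared expert exactly once and
// the load is split across the shared-expert ranks.
constexpr uint32_t sharedMaxTokens = maxBatchSize * epRankSize / sharedExpertRankNum; // 1536

// MoE rank: in the worst case every token contributes
// min(topK, moeExpertNumPerRank) copies to this rank.
constexpr uint32_t moeMaxTokens =
    maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); // 6144

static_assert(sharedMaxTokens == 1536 && moeMaxTokens == 6144, "worked example");
```

Every per-token buffer (x1/x2 tokens, scales, SwiGLU output, GMM2 output) is then sized against this token bound and rounded up to the 512-byte GM alignment.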

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "dispatch_gmm_combine_decode.h"
#include <kernel_operator.h>
#include "lib/matmul_intf.h"
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspace, GM_ADDR tiling)
{
icache_preload(8);
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
op.Process();
}
}
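
How this kernel's two tiling-key branches relate to the host tiling: a hedged reading that assumes `EXEC_FLAG_DEEP_FUSE == 1`. That value lives in dispatch_gmm_combine_decode_tiling.h, which this diff excerpt does not show, but the `TILING_KEY_IS(0) || TILING_KEY_IS(1)` guard implies it.

```cpp
#include <cstdint>

// Assumption: EXEC_FLAG_DEEP_FUSE == 1 (defined in the tiling header, not
// shown in this excerpt).
constexpr uint64_t kShallowKey = 0;  // standalone dispatch kernel runs before GMM1
constexpr uint64_t kDeepFuseKey = 1; // dispatch is fused into the GMM1 grouped matmul

// Mirrors the branch at the end of DispatchGmmCombineDecodeTilingFuncImpl.
constexpr uint64_t SelectTilingKey(uint32_t moeExpertNumPerRank)
{
    return moeExpertNumPerRank == 1 ? kShallowKey : kDeepFuseKey;
}
static_assert(SelectTilingKey(1) == kShallowKey, "one expert per rank -> shallow path");
static_assert(SelectTilingKey(4) == kDeepFuseKey, "several experts per rank -> deep fuse");
```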

View File

@@ -0,0 +1,433 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE_H
#define DISPATCH_GMM_COMBINE_DECODE_H
#include "lib/matmul_intf.h"
#include <kernel_operator.h>
#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/epilogue/tile/tile_broadcast_mul.hpp"
#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp"
#include "catlass/epilogue/tile/tile_swizzle.hpp"
#include "catlass/gemm/block/block_swizzle.hpp"
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h"
#include "catlass/gemm/gemm_type.hpp"
#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h"
#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h"
#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h"
#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h"
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h"
#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h"
#include "dispatch_gmm_combine_decode_tiling.h"
#include "dispatch_gmm_combine_decode_base.h"
using namespace Catlass;
using MmadAtlasA2Custom =
Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
CUSTOM_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG,
CUSTOM_ENABLE_SHUFFLE_K>;
using Gmm1L1TileShape = GemmShape<GMM1_L1M, GMM1_L1N, GMM1_L1K>;
using Gmm1L0TileShape = GemmShape<GMM1_L1M, GMM1_L1N, GMM1_L0K>;
using Gmm1EpilogueTileShape = MatrixShape<GMM1_EPIM, Gmm1L1TileShape::N>;
using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM1_SWIZZLE_OFFSET, GMM1_SWIZZLE_DIRECTION>;
using Gmm2L1TileShape = GemmShape<GMM2_L1M, GMM2_L1N, GMM2_L1K>;
using Gmm2L0TileShape = GemmShape<Gmm2L1TileShape::M, Gmm2L1TileShape::N, GMM2_L0K>;
using Gmm2EpilogueTileShape = MatrixShape<GMM2_EPIM, Gmm2L1TileShape::N>;
using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM2_SWIZZLE_OFFSET, GMM2_SWIZZLE_DIRECTION>;
using Gmm2DispatchPolicy =
Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA<CUSTOM_PRELOAD_STAGES, GMM2_L1A_STAGES, GMM2_L1B_STAGES,
GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR gmReserved, GM_ADDR gmOutputRecvCount,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = DispatchPolicy_;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using XType = XType_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::zN>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu<ubStages, 0>;
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<float, layout::RowMajor>;
using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using EpilogueTileShape = EpilogueTileShape_;
using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
using TileBroadcastOneBlk =
Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
using TileOneBlkColumnBroadcastMul =
Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;
using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
DType, TileRowBroadcastMul, TileBroadcastOneBlk,
TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;
using BlockScheduler = BlockScheduler_;
// kernel level
using ElementGroupList = int64_t;
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
typename GemmKernel::Params params{problemShape,
groupCount,
gmGroupList,
gmA,
layoutA,
gmB,
layoutB,
gmScale,
layoutScale,
gmPerTokenScale,
layoutPerTokenScale,
gmD,
layoutD,
gmDequantScale,
layoutDequantScale,
gmWorkspace,
gmX,
debugGm,
gmexpertIds,
gmExpandIdx,
gmEpSendCount,
gmReserved,
gmOutputRecvCount,
epRankSize,
epRankId,
moeExpertNum,
moeExpertNumPerRank,
sharedExpertNum,
sharedExpertRankNum,
quantMode,
globalBs,
bs,
topK,
tokenLen};
// call a kernel
GemmKernel gemm;
gemm(params);
} else {
typename GemmKernel::Params params{problemShape,
groupCount,
gmGroupList,
gmA,
layoutA,
gmB,
layoutB,
gmScale,
layoutScale,
gmPerTokenScale,
layoutPerTokenScale,
gmD,
layoutD,
gmDequantScale,
layoutDequantScale,
gmWorkspace};
// call a kernel
GemmKernel gemm;
gemm(params);
}
}
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmWorkspace, void *combiner)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = DispatchPolicy_;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::nZ>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine<ubStages, EXEC_FLAG>;
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using EpilogueTileShape = EpilogueTileShape_;
using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
using TileBroadcastOneBlk =
Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
using TileOneBlkColumnBroadcastMul =
Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;
using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
DType, TileRowBroadcastMul, TileBroadcastOneBlk,
TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;
using BlockScheduler = BlockScheduler_;
// kernel level
using ElementGroupList = int64_t;
using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
typename GemmKernel::Params params{
problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale,
layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};
// call a kernel
GemmKernel gemm;
gemm(params);
}
template <TemplateMC2TypeClass>
class DispatchGmmCombineDecode
{
public:
__aicore__ inline DispatchGmmCombineDecode(){};
__aicore__ inline void Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process();
private:
GM_ADDR gmX_;
GM_ADDR gmexpertIds_;
GM_ADDR gmPermuteWeight1_;
GM_ADDR gmPermuteScale1_;
GM_ADDR gmWeight2_;
GM_ADDR gmScale2_;
GM_ADDR gmOutput_;
GM_ADDR gmOutputRecvCount_;
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
uint32_t m_{0};
uint32_t n_{0};
uint32_t k_{0};
uint32_t groupCount_{0};
uint32_t n2_{0};
uint32_t k2_{0};
uint32_t globalRankId_{0};
uint32_t winSizePerRank_{0};
uint32_t blockDim_{0};
uint32_t epRankSize_{0};
uint32_t epRankId_{0};
uint32_t moeExpertNum_{0};
uint32_t moeExpertNumPerRank_{0};
uint32_t sharedExpertNum_{0};
uint32_t sharedExpertRankNum_{0};
uint32_t quantMode_{0};
uint32_t globalBs_{0};
uint32_t bs_{0};
uint32_t maxBs_{0};
uint32_t topK_{0};
AscendC::TPipe *tpipe_{nullptr};
__gm__ HcclOpResParam *winContext_{nullptr};
const DispatchGmmCombineDecodeTilingData *tilingData_;
};
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
tpipe_ = pipe;
blockDim_ = AscendC::GetBlockNum();
winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
gmSmoothScales_ = expert_smooth_scales; // not used now
gmX_ = x; // input token
gmexpertIds_ = expert_ids;
gmPermuteWeight1_ = gmm1_permuted_weight;
gmPermuteScale1_ = gmm1_permuted_weight_scale;
gmWeight2_ = gmm2_weight;
gmScale2_ = gmm2_weight_scale;
gmOutput_ = output;
gmOutputRecvCount_ = outputRecvCount;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
tilingData_ = tilingData;
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum;
sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode;
globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
maxBs_ = globalBs_ / epRankSize_;
bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
if (isShareExpert) {
m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
} else {
m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
}
n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
n2_ = k_;
k2_ = n_ / 2;
}
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
{
GemmCoord gmm1ProblemShape{m_, n_, k_};
GemmCoord gmm2ProblemShape{m_, n2_, k2_};
layout::RowMajor layoutX1{m_, k_};
layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(k_, n_);
layout::VectorLayout layoutScale1{n_};
layout::VectorLayout layoutPerTokenScale1{m_};
layout::RowMajor layoutX2{m_, k2_};
layout::nZ layoutWeight2 = layout::nZ::template MakeLayout<int8_t>(k2_, n2_);
layout::VectorLayout layoutScale2{n2_};
layout::VectorLayout layoutPerTokenScale2{m_};
layout::RowMajor layoutOutput{m_, n2_};
size_t workspaceOffset = 0;
constexpr int32_t reservedWorkSpaceSize = 256 * 1024;
GM_ADDR gmX2 = workspaceGM_;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(int8_t));
GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;
GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (GMM1_L1M * GMM1_L1N) *
WORKSPACE_STAGES * sizeof(int32_t));
GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k2_ * sizeof(float));
GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t));
GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t));
GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(int8_t));
GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * sizeof(float));
GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(m_) * k_ * sizeof(ExpandXType));
GM_ADDR gmReserved = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(reservedWorkSpaceSize);
if constexpr (EXEC_FLAG == 0) {
if constexpr (g_coreType == AscendC::AIV) {
AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
icache_preload(8);
}
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) {
Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
} else {
Arch::CrossCoreWaitFlag(gmm1AivFinished);
}
}
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2,
layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmReserved,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_);
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) {
Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
} else {
Arch::CrossCoreWaitFlag(gmm1AivFinished);
}
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
if (g_coreType == AscendC::AIV) {
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
workspaceGM_, nullptr, tilingData_);
}
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut,
layoutOutput, gmWorkspace, &combiner);
}
#endif // DISPATCH_GMM_COMBINE_DECODE_H
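
For intuition, a worked example of the problem shapes that `Init` and `Process` above derive for the two grouped matmuls, with illustrative sizes inside the ranges the tiling accepts (not taken from this PR):

```cpp
#include <cstdint>

// Illustrative: h = 7168, gmm1HLen = 4096, worst-case m as computed in Init.
constexpr uint32_t m  = 6144;  // max token rows on this MoE rank
constexpr uint32_t k  = 7168;  // token hidden size h
constexpr uint32_t n  = 4096;  // gmm1HLen (gate and up projections, concatenated)
constexpr uint32_t k2 = n / 2; // 2048: SwiGLU halves the hidden width
constexpr uint32_t n2 = k;     // GMM2 projects back to h

// GMM1: int8 [m, k] x [k, n] -> int32 [m, n], then per-token dequant +
//       SwiGLU + quant -> int8 [m, k2]
// GMM2: int8 [m, k2] x [k2, n2] -> int32 [m, n2], then per-token dequant +
//       combine -> bf16/fp16 tokens back on their home ranks
static_assert(k2 == 2048 && n2 == 7168, "shape bookkeeping");
```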

View File

@@ -0,0 +1,14 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/epilogue/block/block_epilogue.hpp"
#include "block_epilogue_per_token_dequant_swiglu.h"
#include "block_epilogue_per_token_dequant.hpp"

View File

@@ -0,0 +1,760 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
#include "../../raw_distributed/cam_moe_distribute_combine.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/matrix_coord.hpp"
#define ENABLE_EP_SEND_COUNT_HASH 0
namespace Catlass::Epilogue::Block {
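// Per-token dequant + combine epilogue (half/bfloat16 scale variant):
// D[m][n] = Cast(C[m][n] * scale[n] * perTokenScale[m]), where C holds the
// int32 GEMM accumulators, scale is the per-column scale and perTokenScale
// the per-row (per-token) scale.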
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class ScaleType_, class PerTokenScaleType_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, ScaleType_, PerTokenScaleType_,
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = typename ScaleType_::Element;
using LayoutScale = typename ScaleType_::Layout;
using ElementPerTokenScale = typename PerTokenScaleType_::Element;
using LayoutPerTokenScale = typename PerTokenScaleType_::Layout;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, int32_t> &&
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>) &&
std::is_same_v<ElementScale, ElementD> && std::is_same_v<ElementPerTokenScale, ElementD>,
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) +
TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {}
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbScaleVMTE2List[i] = eventVMTE2++;
eventUbScaleMTE2VList[i] = eventMTE2V++;
eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(float);
ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * sizeof(float);
ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubPerTokenMul = ubMul;
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
CATLASS_DEVICE
void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
LayoutC const &layoutBlockC, Callback &&callback = Callback{})
{
if (actualBlockShapeMNK.k() == 0) {
return;
}
callback();
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
uint32_t subblockNum = AscendC::GetSubBlockNum();
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
auto layoutUbPerTokenScale =
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
layoutGmTilePerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32);
tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32);
AscendC::PipeBarrier<PIPE_V>();
tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbScaleVMTE2List[UB_STAGES];
int32_t eventUbScaleMTE2VList[UB_STAGES];
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
uint32_t ubListId{0};
AscendC::LocalTensor<float> ubCFp32;
AscendC::LocalTensor<float> ubScaleFp32;
AscendC::LocalTensor<float> ubMul;
AscendC::LocalTensor<float> ubPerTokenScaleFp32;
AscendC::LocalTensor<float> ubPerTokenScaleFp32Brcb;
AscendC::LocalTensor<float> ubPerTokenMul;
TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
CopyUbToGmD copyUbToGmD;
};
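// Float-scale specialization of the epilogue above. With EXEC_FLAG_DEEP_FUSE
// set, dequantized tiles are not written back to GM: DoCombineSend() streams
// them directly into the HCCL windows of the destination EP ranks, fusing the
// combine communication into the GMM2 epilogue.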
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, Gemm::GemmType<float, LayoutScale_>,
Gemm::GemmType<float, LayoutPerTokenScale_>, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_,
TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, int32_t> &&
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {}
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE void AlignUbOffset()
{
size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1);
if (ubMask != 0) {
ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask;
}
}
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo,
Params const &params = Params{})
: resource(resource), calcInfo(calcInfo), params(params)
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbScaleVMTE2List[i] = eventVMTE2++;
eventUbScaleMTE2VList[i] = eventMTE2V++;
eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubPerTokenMul = ubCFp32;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AlignUbOffset();
epSendCountLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t);
AlignUbOffset();
AscendC::GlobalTensor<int32_t> epSendCountGM;
epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_);
uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_;
AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast<uint32_t>(epSendCountSize * sizeof(uint32_t)),
0U, 0U, 0U};
AscendC::DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
#if ENABLE_EP_SEND_COUNT_HASH
tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
uint32_t maxGroupSendCount = 0;
uint32_t groupSendCount = 0;
for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) {
uint32_t prevGroupSendCount = groupSendCount;
groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1);
if (maxGroupSendCount < groupSendCount - prevGroupSendCount) {
maxGroupSendCount = groupSendCount - prevGroupSendCount;
}
}
ubOffset += maxGroupSendCount * sizeof(int32_t);
AlignUbOffset();
// assert: ubOffset <= AscendC::TOTAL_UB_SIZE or AscendC::TOTAL_VEC_LOCAL_SIZE
#endif
}
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
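// Returns a pointer into rankId's HCCL input window (the local window when
// rankId is this rank), offset to the slot reserved for local expert
// expertLocalId.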
CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U)
{
return (GM_ADDR)((calcInfo.epRankId_ == rankId)
? calcInfo.epWinContext_->localWindowsIn
: ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsIn) +
calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
}
#if ENABLE_EP_SEND_COUNT_HASH
CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen)
{
constexpr uint32_t DUPLICATE_MASK_COUNT = 8;
uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1));
if (hashOffsetMask != 0) {
uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask;
if (copyLen < remainMaskCount) {
remainMaskCount = copyLen;
}
uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask;
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, &copyMask, 1, 1,
DUPLICATE_MASK_COUNT);
hashOffset += remainMaskCount;
copyLen -= remainMaskCount;
}
if (copyLen > 0) {
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen);
hashOffset += copyLen;
}
}
#endif
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
{
if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) {
remoteEpRank = calcInfo.epRankId_;
localEpRank = epRank;
} else {
remoteEpRank = epRank;
localEpRank = calcInfo.epRankId_;
}
}
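// Scatters the rows of one finished output tile to their destination ranks:
// walks the epSendCount prefix sums to find the EP rank owning each token
// range, then copies those rows into that rank's window at the position the
// receiver expects for expert expertIdx.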
CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor<ElementD> &ubD, layout::RowMajor &layoutGmTileD,
LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD)
{
const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD);
const uint32_t copyTokenSrcStride =
(layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD));
const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD);
int64_t offsetD = groupOffsetD + tileOffsetD;
uint32_t startToken = offsetD / calcInfo.axisH_;
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
uint32_t itToken = startToken;
uint32_t endToken = startToken + layoutGmTileD.shape(0);
#if ENABLE_EP_SEND_COUNT_HASH
uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken);
#else
constexpr uint32_t epRankStart = 0;
#endif
uint32_t sendCount =
expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
uint32_t prevSendCount = sendCount;
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
if (prevSendCount <= itToken && itToken < sendCount) {
uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken;
AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride,
copyTokenDstStride, 0);
uint32_t remoteEpRank;
uint32_t localEpRank;
SetCombineSendEpRank(epRank, remoteEpRank, localEpRank);
GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) +
localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_;
AscendC::GlobalTensor<ElementD> rankWindow;
rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM);
AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset],
ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams);
itToken += copyTokenCount;
}
}
}
CATLASS_DEVICE
void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK,
GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK,
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutBlockC,
Callback &&callback = Callback{})
{
if (actualBlockShapeMNK.k() == 0) {
return;
}
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
expertOffset = expertIdx * calcInfo.epWorldSize_;
#if ENABLE_EP_SEND_COUNT_HASH
if (currentExpertIdx_ != expertIdx) {
uint32_t hashOffset = 0;
uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1);
for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) {
uint32_t prevSendCount = sendCount;
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount);
}
AscendC::SetFlag<AscendC::HardEvent::V_S>(eventVS);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(eventVS);
currentExpertIdx_ = expertIdx;
}
#endif
}
callback();
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
uint32_t subblockNum = AscendC::GetSubBlockNum();
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
auto layoutUbPerTokenScale =
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
layoutGmTilePerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
tileRowBroadcastMul(ubMul, ubCFp32, ubScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
AscendC::PipeBarrier<PIPE_V>();
tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
auto tileOffsetD = params.layoutD.GetOffset(tileOffset);
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD);
} else {
auto gmTileD = gmD[tileOffsetD];
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
Arch::Resource<ArchTag> &resource;
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbScaleVMTE2List[UB_STAGES];
int32_t eventUbScaleMTE2VList[UB_STAGES];
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
AscendC::LocalTensor<int32_t> epSendCountLocal_;
#if ENABLE_EP_SEND_COUNT_HASH
AscendC::LocalTensor<int32_t> tokenToEpRankHashLocal_;
uint32_t currentExpertIdx_{static_cast<uint32_t>(-1)};
#endif
size_t ubOffset{0};
int32_t eventVMTE2{0};
int32_t eventMTE2V{0};
int32_t eventMTE3V{0};
int32_t eventVMTE3{0};
int32_t eventVS{0};
int32_t eventMTE2S{0};
uint32_t expertOffset{0};
uint32_t ubListId{0};
AscendC::LocalTensor<float> ubCFp32;
AscendC::LocalTensor<float> ubMul;
AscendC::LocalTensor<float> ubPerTokenScaleBrcb;
AscendC::LocalTensor<float> ubPerTokenMul;
TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block
#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP

View File

@@ -0,0 +1,326 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "../tile/tile_stride_muls.h"
#include "../tile/tile_stride_binary.h"
namespace Catlass::Epilogue::Block {
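// Per-token dequant + SwiGLU epilogue for GMM1: dequantizes the int32
// accumulators with per-column and per-token scales, then computes
// D = silu(left) * right over the two column halves of each tile, so the
// output tile has half the columns of the accumulator tile.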
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>, CType_,
Gemm::GemmType<float, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, int32_t> && std::is_same_v<ElementD, float>,
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0,
"The per-token scale buffer (TileShape::ROW floats) must be 32-byte aligned for vector computation.");
static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts.");
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2;
using ChunkTileShape = MatrixShape<TileShape::ROW, CHUNK_TILE_COLUMN>;
using TileStrideMuls = Tile::TileStrideMuls<ArchTag, float, ChunkTileShape, ChunkTileShape, TileShape>;
using TileStrideDiv = Tile::TileStrideDiv<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;
using TileStrideMul = Tile::TileStrideMul<ArchTag, float, ChunkTileShape, ChunkTileShape::COLUMN, TileShape::COLUMN,
ChunkTileShape::COLUMN>;
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {}
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbScaleVMTE2List[i] = eventVMTE2++;
eventUbScaleMTE2VList[i] = eventMTE2V++;
eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
ubTmpMxN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubTmpMx32B = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
CATLASS_DEVICE
void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
LayoutC const &layoutBlockC, Callback &&callback = Callback{})
{
if (actualBlockShapeMNK.k() == 0) {
return;
}
callback();
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto ubChunkTileStride = MakeCoord(static_cast<int64_t>(ChunkTileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = 0; // for 1C1V
uint32_t subblockNum = 1; // for 1C1V
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
auto layoutUbPerTokenScale =
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
layoutGmTilePerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
AscendC::PipeBarrier<PIPE_V>();
tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
AscendC::PipeBarrier<PIPE_V>();
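// SwiGLU over the two column halves, computed in place:
// chunk = exp(-left); chunk = left / (1 + chunk) = silu(left);
// then D = right * silu(left) below.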
tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};
auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbScaleVMTE2List[UB_STAGES];
int32_t eventUbScaleMTE2VList[UB_STAGES];
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
uint32_t ubListId{0};
AscendC::LocalTensor<float> ubTmpMxN;
AscendC::LocalTensor<float> ubTmpMx32B;
AscendC::LocalTensor<float> ubTmpMxChunkN;
TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
TileStrideMuls tileStrideMuls;
TileStrideDiv tileStrideDiv;
TileStrideMul tileStrideMul;
CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/epilogue/dispatch_policy.hpp"
namespace Catlass::Epilogue {
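// Dispatch-policy tags selecting the per-token dequant epilogues; both carry
// the UB staging depth (UB_STAGES) and the execution-mode flag (EXEC_FLAG) as
// compile-time parameters.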
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
struct EpilogueAtlasA2PerTokenDequantSwiglu {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
};
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
struct EpilogueAtlasA2PerTokenDequantCombine {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
};
} // namespace Catlass::Epilogue

View File

@@ -0,0 +1,107 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/catlass.hpp"
namespace Catlass::Epilogue::Tile {
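// Elementwise binary tile ops whose destination and source row strides may
// differ (strides are given in elements and assumed to be whole 32-byte
// blocks); the SwiGLU epilogue uses them to read full-width rows while
// writing half-width (chunk) rows.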
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
int64_t SRC1_STRIDE_>
struct TileStrideBinary {
using ArchTag = ArchTag_;
using ElementCompute = ElementCompute_;
using TileShape = TileShape_;
static constexpr int64_t DST_STRIDE = DST_STRIDE_;
static constexpr int64_t SRC0_STRIDE = SRC0_STRIDE_;
static constexpr int64_t SRC1_STRIDE = SRC1_STRIDE_;
static constexpr uint32_t MAX_REPEAT_TIMES = 255;
static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementCompute);
static constexpr uint32_t DST_BLK_NUM_PER_COLUMN = DST_STRIDE / ELE_NUM_PER_BLK;
static constexpr uint32_t SRC0_BLK_NUM_PER_COLUMN = SRC0_STRIDE / ELE_NUM_PER_BLK;
static constexpr uint32_t SRC1_BLK_NUM_PER_COLUMN = SRC1_STRIDE / ELE_NUM_PER_BLK;
static constexpr uint32_t ROW_NUM_PER_COMPUTE = MAX_REPEAT_TIMES;
static constexpr uint32_t COL_NUM_PER_COMPUTE = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute);
CATLASS_DEVICE
TileStrideBinary()
{
repeatParams.dstBlkStride = 1;
repeatParams.src0BlkStride = 1;
repeatParams.src1BlkStride = 1;
repeatParams.dstRepStride = DST_BLK_NUM_PER_COLUMN;
repeatParams.src0RepStride = SRC0_BLK_NUM_PER_COLUMN;
repeatParams.src1RepStride = SRC1_BLK_NUM_PER_COLUMN;
}
AscendC::BinaryRepeatParams repeatParams;
};
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
int64_t SRC1_STRIDE_>
struct TileStrideMul
: TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_> {
using Base = TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_>;
CATLASS_DEVICE
TileStrideMul() : Base() {}
CATLASS_DEVICE
void operator()(AscendC::LocalTensor<typename Base::ElementCompute> const &ubDst,
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc0,
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc1)
{
for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) {
uint32_t residueM = Base::TileShape::ROW - rowOffset;
uint8_t repeatTimes =
static_cast<uint8_t>((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM);
for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) {
uint32_t residueN = Base::TileShape::COLUMN - colOffset;
uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN;
AscendC::Mul(ubDst[rowOffset * Base::DST_STRIDE + colOffset],
ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset],
ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams);
}
}
}
};
template <class ArchTag_, class ElementCompute_, class TileShape_, int64_t DST_STRIDE_, int64_t SRC0_STRIDE_,
int64_t SRC1_STRIDE_>
struct TileStrideDiv
: TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_> {
using Base = TileStrideBinary<ArchTag_, ElementCompute_, TileShape_, DST_STRIDE_, SRC0_STRIDE_, SRC1_STRIDE_>;
CATLASS_DEVICE
TileStrideDiv() : Base() {}
CATLASS_DEVICE
void operator()(AscendC::LocalTensor<typename Base::ElementCompute> const &ubDst,
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc0,
AscendC::LocalTensor<typename Base::ElementCompute> const &ubSrc1)
{
for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) {
uint32_t residueM = Base::TileShape::ROW - rowOffset;
uint8_t repeatTimes =
static_cast<uint8_t>((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM);
for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) {
uint32_t residueN = Base::TileShape::COLUMN - colOffset;
uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN;
AscendC::Div(ubDst[rowOffset * Base::DST_STRIDE + colOffset],
ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset],
ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams);
}
}
}
};
} // namespace Catlass::Epilogue::Tile

View File

@@ -0,0 +1,59 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/catlass.hpp"
namespace Catlass::Epilogue::Tile {
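// Scalar-multiply tile op (dst = src * scalar) supporting different row
// strides for source and destination; all three shapes must share the same
// row count.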
template <class ArchTag_, class ElementCompute_, class TileShape_, class DstTileShape_, class SrcTileShape_>
struct TileStrideMuls {
using ArchTag = ArchTag_;
using ElementCompute = ElementCompute_;
using TileShape = TileShape_;
using DstTileShape = DstTileShape_;
using SrcTileShape = SrcTileShape_;
static_assert(DstTileShape::ROW == SrcTileShape::ROW && DstTileShape::ROW == TileShape::ROW, "Error");
CATLASS_DEVICE
TileStrideMuls() {}
CATLASS_DEVICE
void operator()(AscendC::LocalTensor<ElementCompute> const &ubDst,
AscendC::LocalTensor<ElementCompute> const &ubSrc, ElementCompute scalar)
{
constexpr uint32_t maxRepeatTimes = 255;
constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute);
constexpr uint32_t dstBlkNumPerColumn = DstTileShape::COLUMN / eleNumPerBlk;
constexpr uint32_t srcBlkNumPerColumn = SrcTileShape::COLUMN / eleNumPerBlk;
AscendC::UnaryRepeatParams repeatParams;
repeatParams.dstBlkStride = 1;
repeatParams.srcBlkStride = 1;
repeatParams.dstRepStride = dstBlkNumPerColumn;
repeatParams.srcRepStride = srcBlkNumPerColumn;
constexpr uint32_t rowNumPerCompute = maxRepeatTimes;
constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute);
for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) {
uint32_t residueM = TileShape::ROW - rowOffset;
uint8_t repeatTimes = static_cast<uint8_t>((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM);
for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) {
uint32_t residueN = TileShape::COLUMN - colOffset;
uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN;
AscendC::Muls(ubDst[rowOffset * DstTileShape::COLUMN + colOffset],
ubSrc[rowOffset * SrcTileShape::COLUMN + colOffset], scalar, mask, repeatTimes,
repeatParams);
}
}
}
};
} // namespace Catlass::Epilogue::Tile

View File

@@ -0,0 +1,13 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/gemm/block/block_mmad.hpp"
#include "block_mmad_preload_async_with_callback_resident_a.h"

View File

@@ -0,0 +1,420 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/helper.hpp"
namespace Catlass::Gemm::Block {
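// Block-level matmul with asynchronous preloading and a completion callback;
// as the policy name suggests, the A tile is kept resident for reuse. The
// L1/L0A/L0B/L0C staging depths are compile-time parameters.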
template <uint32_t PRELOAD_STAGES_, uint32_t L1A_STAGES_, uint32_t L1B_STAGES_, uint32_t L0A_STAGES_,
uint32_t L0B_STAGES_, uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_,
class L1TileShape_, class L0TileShape_, class AType_, class BType_, class CType_, class BiasType_,
class TileCopy_, class TileMmad_>
struct BlockMmad<
MmadAtlasA2PreloadAsyncWithCallbackResidentA<PRELOAD_STAGES_, L1A_STAGES_, L1B_STAGES_, L0A_STAGES_, L0B_STAGES_,
L0C_STAGES_, ENABLE_UNIT_FLAG_, ENABLE_SHUFFLE_K_>,
L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> {
public:
// Type Aliases
using DispatchPolicy =
MmadAtlasA2PreloadAsyncWithCallbackResidentA<PRELOAD_STAGES_, L1A_STAGES_, L1B_STAGES_, L0A_STAGES_,
L0B_STAGES_, L0C_STAGES_, ENABLE_UNIT_FLAG_, ENABLE_SHUFFLE_K_>;
using ArchTag = typename DispatchPolicy::ArchTag;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using ElementA = typename AType_::Element;
using LayoutA = typename AType_::Layout;
using ElementB = typename BType_::Element;
using LayoutB = typename BType_::Layout;
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using TileMmad = TileMmad_;
using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
using CopyL0CToGm = typename TileCopy_::CopyL0CToGm;
using ElementAccumulator =
typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
using LayoutCInL0 = layout::zN;
using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;
static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
static constexpr uint32_t L1A_STAGES = DispatchPolicy::L1A_STAGES;
static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES;
static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;
static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;
// L1 tile size
static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
// L0 tile size
static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);
// Check LayoutC
static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only supports RowMajor for now!");
// Check L1TileShape
static_assert(L1A_TILE_SIZE * L1A_STAGES + L1B_TILE_SIZE * L1B_STAGES <= ArchTag::L1_SIZE,
"L1TileShape exceeds the L1 space!");
// Check L0TileShape
static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeds the L0A space!");
static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeds the L0B space!");
static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeds the L0C space!");
static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
"L1 and L0 basic blocks that differ on the m and n axes are not supported yet");
static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(L1TileShape::M, L1TileShape::K);
static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(L1TileShape::K, L1TileShape::N);
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
{
InitL1(resource, l1BufAddrStart);
InitL0A(resource);
InitL0B(resource);
InitL0C(resource);
}
CATLASS_DEVICE
~BlockMmad()
{
SynchronizeBlock();
for (uint32_t i = 0; i < L1A_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
}
for (uint32_t i = 0; i < L1B_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
}
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
}
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
}
}
CATLASS_DEVICE
void operator()(AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe,
Callback const &callbackAfterFixpipe)
{
uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
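// Resident-A reuse: when all k tiles of A fit in the L1A ring
// (kTileCount == L1A_STAGES) and this call sees the same A block as the
// previous one, the GM->L1 copy of A below is skipped.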
bool useResidentA =
(kTileCount == L1A_STAGES) && (!isFirstLoad) && (gmBlockA.GetPhyAddr() == lastGmBlockA.GetPhyAddr());
isFirstLoad = false;
lastGmBlockA = gmBlockA;
uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());
uint32_t startTileIdx = 0;
if constexpr (ENABLE_SHUFFLE_K) {
startTileIdx = AscendC::GetBlockIdx() % kTileCount;
}
for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx)
: (startTileIdx + kLoopIdx - kTileCount);
uint32_t kActual =
(kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);
// Emit load instructions from GM to L1
MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
// Load the matrix A tile for this k step from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1AListId]);
if (!useResidentA) {
auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
copyGmToL1A(l1ATensorList[l1AListId], gmTileA, L1A_LAYOUT, layoutTileA);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1AListId]);
// Load the matrix B tile for this k step from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1BListId]);
auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
copyGmToL1B(l1BTensorList[l1BListId], gmTileB, L1B_LAYOUT, layoutTileB);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1BListId]);
// Once the number of in-flight preloads reaches the limit, run the mmad for the oldest preloaded L1 tile
if (preloadCount == PRELOAD_STAGES) {
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
}
// Record the parameters of the tile just loaded for its deferred mmad
uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES)
? (l1TileMmadParamsId + preloadCount)
: (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
l1TileMmadParams.l1AListId = l1AListId;
l1TileMmadParams.l1BListId = l1BListId;
l1TileMmadParams.mRound = mRound;
l1TileMmadParams.nRound = nRound;
l1TileMmadParams.kActual = kActual;
l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
if (kLoopIdx == kTileCount - 1) {
l1TileMmadParams.gmBlockC = gmBlockC;
l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe;
l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe;
}
if (preloadCount < PRELOAD_STAGES) {
++preloadCount;
} else {
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
}
l1AListId = (l1AListId + 1 < L1A_STAGES) ? (l1AListId + 1) : 0;
l1BListId = (l1BListId + 1 < L1B_STAGES) ? (l1BListId + 1) : 0;
}
}
CATLASS_DEVICE
void SynchronizeBlock()
{
while (preloadCount > 0) {
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
--preloadCount;
}
}
private:
struct L1TileMmadParams {
uint32_t l1AListId;
uint32_t l1BListId;
uint32_t mRound;
uint32_t nRound;
uint32_t kActual;
bool isKLoopFirst;
bool isKLoopLast;
AscendC::GlobalTensor<ElementC> gmBlockC;
LayoutC layoutCInGm;
Callback callbackBeforeFixpipe;
Callback callbackAfterFixpipe;
CATLASS_DEVICE
L1TileMmadParams() = default;
};
CATLASS_DEVICE
void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
{
uint32_t l1AOffset = l1BufAddrStart;
for (uint32_t i = 0; i < L1A_STAGES; ++i) {
l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
l1AEventList[i] = i;
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
}
uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1A_STAGES;
for (uint32_t i = 0; i < L1B_STAGES; ++i) {
l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
l1BEventList[i] = i + L1A_STAGES;
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
}
}
CATLASS_DEVICE
void InitL0A(Arch::Resource<ArchTag> &resource)
{
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
l0AEventList[i] = i;
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
}
}
CATLASS_DEVICE
void InitL0B(Arch::Resource<ArchTag> &resource)
{
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
l0BEventList[i] = i + L0A_STAGES;
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
}
}
CATLASS_DEVICE
void InitL0C(Arch::Resource<ArchTag> &resource)
{
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
l0CEventList[i] = i;
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
}
}
CATLASS_DEVICE
void L1TileMmad(L1TileMmadParams const &params)
{
uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
auto &l1ATensor = l1ATensorList[params.l1AListId];
auto &l1BTensor = l1BTensorList[params.l1BListId];
auto &l0CTensor = l0CTensorList[l0CListId];
LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));
if constexpr (!ENABLE_UNIT_FLAG) {
if (params.isKLoopFirst) {
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
}
}
for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
uint32_t mPartActual =
(mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);
for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
uint32_t kPartActual =
(kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);
auto &l0ATile = l0ATensorList[l0AListId];
auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
if ((mPartIdx == 0) && (kPartIdx == 0)) {
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1AListId]);
}
copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1AListId]);
}
for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
uint32_t nPartActual =
(nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);
auto &l0BTile = l0BTensorList[l0BListId];
auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
if ((kPartIdx == 0) && (nPartIdx == 0)) {
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1BListId]);
}
copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1BListId]);
}
AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];
AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
// If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
bool initC = (params.isKLoopFirst && (kPartIdx == 0));
// If the unit flag is enabled, set it according to the computation progress
uint8_t unitFlag = 0b00;
if constexpr (ENABLE_UNIT_FLAG) {
if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) &&
(nPartIdx == nPartLoop - 1)) {
unitFlag = 0b11;
} else {
unitFlag = 0b10;
}
}
tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
}
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
}
}
if (params.isKLoopLast) {
auto layoutCInGm = params.layoutCInGm;
params.callbackBeforeFixpipe();
if constexpr (!ENABLE_UNIT_FLAG) {
AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
} else {
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
}
l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;
params.callbackAfterFixpipe();
}
}
AscendC::LocalTensor<ElementA> l1ATensorList[L1A_STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[L1B_STAGES];
int32_t l1AEventList[L1A_STAGES];
int32_t l1BEventList[L1B_STAGES];
uint32_t l1AListId{0};
uint32_t l1BListId{0};
AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
int32_t l0AEventList[L0A_STAGES];
uint32_t l0AListId{0};
AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
int32_t l0BEventList[L0B_STAGES];
uint32_t l0BListId{0};
AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
int32_t l0CEventList[L0C_STAGES_];
uint32_t l0CListId{0};
L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
uint32_t l1TileMmadParamsId{0};
uint32_t preloadCount{0};
TileMmad tileMmad;
CopyGmToL1A copyGmToL1A;
CopyGmToL1B copyGmToL1B;
CopyL1ToL0A copyL1ToL0A;
CopyL1ToL0B copyL1ToL0B;
CopyL0CToGm copyL0CToGm;
bool isFirstLoad{true};
AscendC::GlobalTensor<ElementA> lastGmBlockA;
};
} // namespace Catlass::Gemm::Block
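The block above overlaps GM->L1 loads with compute: up to PRELOAD_STAGES loads are issued before any mmad runs, then each further load triggers the mmad of the oldest outstanding tile, and SynchronizeBlock drains whatever is left. A host-side sketch of just that bookkeeping (printf stands in for the device-side MTE2/mmad work; hardware events are omitted):

#include <cstdio>

// Hedged model of the preload schedule in the BlockMmad above.
template <unsigned PRELOAD_STAGES>
struct PreloadModel {
    unsigned head = 0;     // oldest outstanding tile slot (l1TileMmadParamsId)
    unsigned inFlight = 0; // outstanding preloads (preloadCount)
    void OnKTile(unsigned kTileIdx)
    {
        std::printf("issue GM->L1 load for k-tile %u\n", kTileIdx);
        if (inFlight == PRELOAD_STAGES) {
            std::printf("run mmad for slot %u\n", head);
        }
        // the new tile's parameters land in slot (head + inFlight) % PRELOAD_STAGES
        if (inFlight < PRELOAD_STAGES) {
            ++inFlight;
        } else {
            head = (head + 1 < PRELOAD_STAGES) ? (head + 1) : 0;
        }
    }
    void Drain() // mirrors SynchronizeBlock()
    {
        while (inFlight > 0) {
            std::printf("drain mmad for slot %u\n", head);
            head = (head + 1 < PRELOAD_STAGES) ? (head + 1) : 0;
            --inFlight;
        }
    }
};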

View File

@@ -0,0 +1,28 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/gemm/dispatch_policy.hpp"
namespace Catlass::Gemm {
template <uint32_t PRELOAD_STAGES_, uint32_t L1A_STAGES_, uint32_t L1B_STAGES_, uint32_t L0A_STAGES_,
uint32_t L0B_STAGES_, uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncWithCallbackResidentA : public MmadAtlasA2Async {
static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // number of stages by which load instructions are issued in advance
static constexpr uint32_t L1A_STAGES = L1A_STAGES_;
static constexpr uint32_t L1B_STAGES = L1B_STAGES_;
static constexpr uint32_t L0A_STAGES = L0A_STAGES_;
static constexpr uint32_t L0B_STAGES = L0B_STAGES_;
static constexpr uint32_t L0C_STAGES = L0C_STAGES_;
static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;
static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;
};
} // namespace Catlass::Gemm
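A hypothetical instantiation, to make the template parameters concrete (the stage counts below are illustrative, not values taken from this PR):

// Double-buffered L1/L0 operands, a single L0C stage, unit-flag fixpipe
// and shuffle-K both enabled.
using ExamplePolicy = Catlass::Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA<
    /*PRELOAD_STAGES*/ 2, /*L1A_STAGES*/ 2, /*L1B_STAGES*/ 2,
    /*L0A_STAGES*/ 2, /*L0B_STAGES*/ 2, /*L0C_STAGES*/ 1,
    /*ENABLE_UNIT_FLAG*/ true, /*ENABLE_SHUFFLE_K*/ true>;
static_assert(ExamplePolicy::PRELOAD_STAGES == 2, "the policy exposes its stage counts");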

View File

@@ -0,0 +1,355 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
#include "../../raw_distributed/cam_moe_distribute_combine.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
namespace Catlass::Gemm::Kernel {
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_,
uint32_t WORKSPACE_STAGES_, class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace
{
public:
using BlockMmad = BlockMmad_;
using ArchTag = typename BlockMmad::ArchTag;
using L1TileShape = typename BlockMmad::L1TileShape;
using ElementA = typename BlockMmad::ElementA;
using LayoutA = typename BlockMmad::LayoutA;
using ElementB = typename BlockMmad::ElementB;
using LayoutB = typename BlockMmad::LayoutB;
using ElementC = typename BlockMmad::ElementC;
using LayoutC = typename BlockMmad::LayoutC;
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using BlockEpilogue = BlockEpilogue_;
using ElementScale = typename BlockEpilogue::ElementScale;
using LayoutScale = typename BlockEpilogue::LayoutScale;
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
using ElementD = typename BlockEpilogue::ElementD;
using LayoutD = typename BlockEpilogue::LayoutD;
using EpilogueParams = typename BlockEpilogue::Params;
using BlockScheduler = BlockScheduler_;
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
using ElementGroupList = ElementGroupList_;
/// Parameters structure
struct Params {
// Data members
GemmCoord problemShape;
uint32_t problemCount;
__gm__ ElementGroupList_ *ptrGroupList;
__gm__ ElementA *ptrA;
LayoutA layoutA;
__gm__ ElementB *ptrB;
LayoutB layoutB;
__gm__ ElementScale *ptrScale;
LayoutScale layoutScale;
__gm__ ElementPerTokenScale *ptrPerTokenScale;
LayoutPerTokenScale layoutPerTokenScale;
__gm__ ElementD *ptrD;
LayoutD layoutD;
GM_ADDR ptrWorkspace;
void *combiner;
// Methods
CATLASS_DEVICE
Params() {}
CATLASS_DEVICE
Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_,
GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_,
LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_,
void *combiner_)
: problemShape(problemShape_),
problemCount(problemCount_),
ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)),
ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)),
layoutA(layoutA_),
ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)),
layoutB(layoutB_),
ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)),
layoutScale(layoutScale_),
ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)),
layoutD(layoutD_),
ptrWorkspace(ptrWorkspace_),
combiner(combiner_)
{}
};
// Methods
CATLASS_DEVICE
GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace()
{
Arch::FlagID flagId = 0;
for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) {
flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++);
flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++);
aicWaitFuncList[stageId] = {this, stageId};
aicSetFuncList[stageId] = {this, stageId};
}
}
template <int32_t CORE_TYPE = g_coreType>
CATLASS_DEVICE void operator()(Params const &params);
template <>
CATLASS_DEVICE void operator()<AscendC::AIC>(Params const &params)
{
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
// Global tensors that view the full GM buffers
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer(params.ptrA);
AscendC::GlobalTensor<ElementB> gmB;
gmB.SetGlobalBuffer(params.ptrB);
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
uint32_t coreIdx = AscendC::GetBlockIdx();
uint32_t coreNum = AscendC::GetBlockNum();
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
uint32_t stageId = 0;
uint32_t stageUsed = 0;
uint32_t startCoreIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
LayoutB layoutB = params.layoutB;
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
// Determine the starting loopIdx of the current core within the current groupIdx
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
// Loop through the matmul of each groupIdx
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
// Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
Callback callbackBeforeFixpipe{};
if (stageUsed == WORKSPACE_STAGES) {
callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]);
} else {
++stageUsed;
}
Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]);
// Compute initial location in logical coordinates
MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
int64_t gmOffsetB = layoutB.GetOffset(offsetB);
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
// Compute block-scoped matrix multiply-add
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe);
} else {
callbackBeforeFixpipe();
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
gmC[gmOffsetC], layoutC, actualBlockShape);
callbackAfterFixpipe();
}
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
while (stageUsed > 0) {
uint32_t aivComputeStageId =
(stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed);
Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]);
--stageUsed;
}
}
template <>
CATLASS_DEVICE void operator()<AscendC::AIV>(Params const &params)
{
auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> *)params.combiner;
{
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
if (get_subblockid() == 0) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID);
}
}
BlockScheduler blockScheduler;
BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo());
uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
uint32_t coreNum = AscendC::GetBlockNum();
int64_t gmGroupOffsetScale = 0;
int64_t gmGroupOffsetPerTokenScale = 0;
int64_t gmGroupOffsetD = 0;
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
uint32_t stageId = 0;
uint32_t startCoreIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
LayoutScale layoutScale = params.layoutScale;
LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
params.ptrD + gmGroupOffsetD,
layoutD};
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
blockEpilogue.UpdateParams(epilogueParams);
uint32_t coreLoops = blockScheduler.GetCoreLoops();
GemmCoord blockShapeMNK = L1TileShape::ToCoord();
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK);
MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
auto gmBlockC = gmC[gmOffsetC];
auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN());
Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]);
blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC,
layoutBlockC);
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]);
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
gmGroupOffsetScale += inGroupProblemShape.n();
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
}
icache_preload(4);
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
if (get_subblockid() == 0) {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->AllToAllSend();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
} else {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->ReducePermute();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
}
} else {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->Process();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
}
}
private:
friend struct AicWaitFunc;
friend struct AicSetFunc;
struct AicWaitFunc {
using MatmulKernel =
GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicWaitFunc() = default;
CATLASS_DEVICE
void operator()() const
{
Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]);
}
MatmulKernel *ptr{nullptr};
uint32_t stageId;
};
struct AicSetFunc {
using MatmulKernel =
GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicSetFunc() = default;
CATLASS_DEVICE
void operator()() const
{
Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]);
}
MatmulKernel *ptr{nullptr};
uint32_t stageId;
};
Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES];
Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES];
AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES];
AicSetFunc aicSetFuncList[WORKSPACE_STAGES];
Arch::Resource<ArchTag> resource;
};
} // namespace Catlass::Gemm::Kernel
#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
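The kernel above hands data from cube cores (AIC) to vector cores (AIV) through a ring of WORKSPACE_STAGES workspace slots: AIC waits on flagAivFinishCompute before overwriting a slot and raises flagAicFinishStore after the fixpipe, while AIV does the converse. A minimal two-thread host model of that handshake (requires C++20; the semaphores stand in for the cross-core flags, and the stage/block counts are illustrative):

#include <cstdio>
#include <semaphore>
#include <thread>

constexpr int kStages = 2;
constexpr int kBlocks = 8;
std::counting_semaphore<kStages> slotFree(kStages); // AIV-finished-compute flags
std::counting_semaphore<kStages> slotFull(0);       // AIC-finished-store flags

int main()
{
    std::thread aic([] { // cube cores: matmul into the workspace ring
        for (int b = 0; b < kBlocks; ++b) {
            slotFree.acquire(); // callbackBeforeFixpipe (only blocks once all stages are in use)
            std::printf("AIC: store block %d into slot %d\n", b, b % kStages);
            slotFull.release(); // callbackAfterFixpipe
        }
    });
    std::thread aiv([] { // vector cores: per-token dequant epilogue
        for (int b = 0; b < kBlocks; ++b) {
            slotFull.acquire(); // Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId])
            std::printf("AIV: epilogue block %d from slot %d\n", b, b % kStages);
            slotFree.release(); // Arch::CrossCoreSetFlag(flagAivFinishComputeList[stageId])
        }
    });
    aic.join();
    aiv.join();
    return 0;
}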

View File

@@ -0,0 +1,814 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H
#define CAM_MOE_DISTRIBUTE_COMBINE_H
#define OPT_RANK_OFFSET 512
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "../../dispatch_gmm_combine_decode_base.h"
#include "../../dispatch_gmm_combine_decode_tiling.h"
namespace MoeDistributeCombineImpl {
constexpr uint8_t BUFFER_NUM = 2; // double buffering
constexpr uint32_t STATE_OFFSET = 512;
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1 MB
constexpr uint32_t RANK_SIZE_ON_WIN_512 = 512 * 1024;
constexpr uint32_t RANK_SIZE_ON_WIN_256 = 256 * 1024;
constexpr uint32_t TP_RANK_SIZE_ON_WIN = 0;
constexpr uint32_t UB_ALIGN = 32;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint8_t EP_DOMAIN = 0;
constexpr uint8_t TP_DOMAIN = 1;
constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
template <AscendC::HardEvent event>
__aicore__ inline void SyncFunc()
{
int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
AscendC::SetFlag<event>(eventID);
AscendC::WaitFlag<event>(eventID);
}
using namespace AscendC;
struct CombineCalcInfo {
uint64_t expertPerSizeOnWin_;
uint32_t epRankId_;
uint32_t epWorldSize_;
uint32_t moeExpertPerRankNum_;
uint32_t sharedExpertRankNum_;
uint32_t axisH_;
uint32_t moeSendNum_;
bool isShardExpert_;
GM_ADDR epSendCount_;
__gm__ HcclOpResParam *epWinContext_;
uint64_t winDataSizeOffset_;
};
template <TemplateMC2TypeClass>
class CamMoeDistributeCombine
{
public:
__aicore__ inline CamMoeDistributeCombine() {}
__aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process();
__aicore__ inline void AllToAllSend();
__aicore__ inline void ReducePermute();
__aicore__ inline CombineCalcInfo &GetCalcInfo()
{
return calcInfo_;
}
__aicore__ inline void TPipeSet(AscendC::TPipe *pipe)
{
tpipe_ = pipe;
}
private:
__aicore__ inline void InitStatusTargetSum();
__aicore__ inline void AlltoAllBuffInit();
__aicore__ inline void ReduceScatterTrans();
__aicore__ inline void SetWaitTpStatusAndDisPatch();
__aicore__ inline void CustomAdd(LocalTensor<ExpandXType> &dst, LocalTensor<ExpandXType> &src0,
LocalTensor<ExpandXType> &src1, uint32_t dataCnt);
__aicore__ inline void ExpertAlltoAllDispatchInnerCopyAdd(uint32_t tokenNumLoop, uint32_t srcStartTokenIdx,
uint32_t ep, uint32_t expertIdx);
__aicore__ inline void ExpertAlltoAllDispatchCopyAdd();
__aicore__ inline void LocalWindowCopy();
__aicore__ inline void BuffInit();
__aicore__ inline void SplitCoreCal();
__aicore__ inline void SetStatus();
__aicore__ inline void WaitDispatch();
__aicore__ GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t domain, const uint8_t expertLocalId = 0U)
{
if (domain == EP_DOMAIN) {
return (GM_ADDR)((epRankId_ == rankId)
? epWinContext_->localWindowsIn
: ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsIn) +
winDataSizeOffset_ + expertLocalId * expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
} else {
return (GM_ADDR)((tpRankId_ == rankId)
? tpWinContext_->localWindowsIn
: ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsIn) +
winDataSizeOffset_ + rankId * OPT_RANK_OFFSET;
}
}
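// Note on the addressing above: the window resolves to the local HCCL
// window when rankId is this rank, and to the peer's windowsIn (reached
// through remoteRes) otherwise; winDataSizeOffset_ selects the half of the
// double-buffered window chosen by dataState_, expertLocalId picks the
// per-expert slice, and rankId * OPT_RANK_OFFSET staggers per-rank bases
// by 512 B.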
__aicore__ GM_ADDR GetWinStateAddrByRankId(const int32_t rankId, const uint8_t domain)
{
if (domain == EP_DOMAIN) {
return (GM_ADDR)((epRankId_ == rankId)
? epWinContext_->localWindowsExp
: ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsExp) +
dataState_ * WIN_STATE_OFFSET;
} else {
return (GM_ADDR)((tpRankId_ == rankId)
? tpWinContext_->localWindowsExp
: ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsExp) +
dataState_ * WIN_STATE_OFFSET;
}
}
__aicore__ inline uint32_t MIN(uint32_t x, uint32_t y)
{
return (x < y) ? x : y;
}
__aicore__ static void DoCombineRecv(void *ptr)
{
auto *combiner = (CamMoeDistributeCombine<TemplateMC2TypeFunc> *)ptr;
combiner->ReducePermute();
}
TPipe *tpipe_{nullptr};
GlobalTensor<ExpandXType> expandXGM_;
GlobalTensor<ExpandIdxType> expertIdsGM_;
GlobalTensor<ExpandIdxType> expandIdxGM_;
GlobalTensor<ExpandIdxType> epSendCountGM_;
GlobalTensor<ExpandIdxType> tpSendCountGM_;
GlobalTensor<float> expandScalesGM_;
GlobalTensor<ExpandXType> expandOutGlobal_;
GlobalTensor<ExpandXType> rankWindow_;
GlobalTensor<int32_t> rankStates_;
GlobalTensor<float> epStatusSpaceGlobalTensor_;
GlobalTensor<float> tpStatusSpaceGlobalTensor_;
GlobalTensor<ExpandXType> tpRankWindow_;
GlobalTensor<ExpandXType> rowTmpGlobal_;
GM_ADDR workspaceGM_;
GM_ADDR epWindowGM_;
GM_ADDR epStatusSpaceGm_;
GM_ADDR tpWindowGM_;
GM_ADDR tpStatusSpaceGm_;
GM_ADDR stateGM_;
LocalTensor<ExpandXType> winTpSendCountTensor_;
LocalTensor<ExpandXType> gmTpSendCountTensor_;
LocalTensor<ExpandXType> outTensor_;
LocalTensor<float> winTpSendCountFloatTensor_;
LocalTensor<float> gmTpSendCountFloatTensor_;
LocalTensor<ExpandIdxType> epSendCountLocal_;
CombineCalcInfo calcInfo_;
uint32_t axisBS_{0};
uint32_t axisMaxBs_{0};
uint32_t axisH_{0};
uint32_t axisK_{0};
uint32_t aivNum_{0};
uint32_t epWorldSize_{0};
uint32_t tpWorldSize_{0};
uint32_t epRankId_{0};
uint32_t tpRankId_{0};
uint32_t coreIdx_{0}; // aiv id
uint32_t sharedExpertRankNum_{0};
uint32_t moeExpertNum_{0};
uint32_t moeExpertPerRankNum_{0};
uint32_t moeSendNum_{0}; // moeExpertPerRankNum_ * epWorldSize_
uint32_t tpScatterNum_{0};
uint32_t firstTpTokenEndIdx_{0};
uint32_t firstTpTokenEndOffset_{0};
uint32_t endTok_{0};
__gm__ HcclOpResParam *epWinContext_{nullptr};
__gm__ HcclOpResParam *tpWinContext_{nullptr};
uint32_t epDataOffsetOnWin_{0};
uint32_t tpDataOffsetOnWin_{0};
uint32_t epStateOffsetOnWin_{0};
uint32_t tpStateOffsetOnWin_{0};
uint32_t axisHFloatSize_{0};
uint32_t axisHExpandXTypeSize_{0};
uint32_t bsKNum_{0};
uint32_t startRankId_{0};
uint32_t endRankId_{0};
uint32_t sendRankNum_{0};
uint32_t ubSize_{0};
uint32_t dataState_{0};
uint32_t stateOffset_{0};
uint64_t winDataSizeOffset_{0};
uint64_t expertPerSizeOnWin_{0};
uint64_t totalWinSize_{0};
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> moeQueue_;
TQue<QuePosition::VECIN, 1> moeSumQueue_;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> gmTpSendCountQueue_;
TQue<QuePosition::VECIN, 1> gmTpSendCountInQueue_;
TQue<QuePosition::VECIN, 1> winTpSendCountInQueue_;
TQue<QuePosition::VECOUT, 1> xOutQueue_;
TBuf<> readStateBuf_;
TBuf<> expertIdsBuf_;
TBuf<> expandScalesBuf_;
TBuf<> rowTmpFloatBuf_;
TBuf<> sumFloatBuf_;
TBuf<> mulBuf_;
TBuf<> sendCountBuf_;
TBuf<> indexCountsBuf_;
TBuf<> winTpSendCountFloatBuf_;
TBuf<> gmTpSendCountFloatBuf_;
TBuf<> tokenBuf_;
TBuf<> statusBuf_;
TBuf<> gatherMaskOutBuf_; // gather mask output buf
TBuf<> gatherTmpBuf_;
TBuf<> statusSumOutBuf_;
float sumTarget_{0.0};
int32_t epStateValue_;
bool isShardExpert_{false};
};
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales,
GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
tpipe_ = pipe;
coreIdx_ = GetBlockIdx();
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
epWinContext_ = (__gm__ HcclOpResParam *)contextGM0;
GlobalTensor<int32_t> selfDataStatusTensor;
GM_ADDR statusDataSpaceGm = (GM_ADDR)epWinContext_->localWindowsExp;
selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
__asm__ __volatile__("");
dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN);
if (dataState_ == 0) {
selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1;
} else {
selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
__asm__ __volatile__("");
pipe_barrier(PIPE_ALL);
workspaceGM_ = workspaceGM;
expandXGM_.SetGlobalBuffer((__gm__ ExpandXType *)expandX);
expertIdsGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds);
expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
aivNum_ = get_block_num();
} else {
aivNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum;
}
ubSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalUbSize;
sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
moeExpertPerRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
axisMaxBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_;
moeSendNum_ = epWorldSize_ * moeExpertPerRankNum_;
tpWorldSize_ = 1;
tpRankId_ = 0;
totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize;
stateOffset_ = (moeSendNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET;
expertPerSizeOnWin_ =
static_cast<uint64_t>(axisMaxBs_) * static_cast<uint64_t>(axisH_) * static_cast<uint64_t>(sizeof(ExpandXType));
winDataSizeOffset_ = static_cast<uint64_t>(dataState_) * static_cast<uint64_t>(moeSendNum_) * expertPerSizeOnWin_;
epWindowGM_ = GetWinAddrByRankId(epRankId_, EP_DOMAIN);
epStatusSpaceGm_ = GetWinStateAddrByRankId(epRankId_, EP_DOMAIN);
epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)epStatusSpaceGm_);
epDataOffsetOnWin_ = epRankId_ * moeExpertPerRankNum_ * static_cast<uint32_t>(expertPerSizeOnWin_);
epStateOffsetOnWin_ = epRankId_ * stateOffset_;
isShardExpert_ = (epRankId_ < sharedExpertRankNum_);
axisHFloatSize_ = axisH_ * sizeof(float);
axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType);
bsKNum_ = axisBS_ * axisK_;
if constexpr (IsNeedReduceScatter) {
tpSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)tpSendCount);
tpWindowGM_ = GetWinAddrByRankId(tpRankId_, TP_DOMAIN);
tpStatusSpaceGm_ = GetWinStateAddrByRankId(tpRankId_, TP_DOMAIN);
tpStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)tpStatusSpaceGm_);
tpDataOffsetOnWin_ = tpRankId_ * TP_RANK_SIZE_ON_WIN;
tpStateOffsetOnWin_ = tpRankId_ * stateOffset_;
uint32_t tpScatterRankWinOffset = (tpRankId_ == 0) ? TP_RANK_SIZE_ON_WIN : 0;
GM_ADDR rankGM = tpWindowGM_ + tpScatterRankWinOffset;
tpRankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
}
InitStatusTargetSum();
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
coreIdx_ = get_block_idx();
}
SplitCoreCal();
calcInfo_.epRankId_ = epRankId_;
calcInfo_.epWorldSize_ = epWorldSize_;
calcInfo_.expertPerSizeOnWin_ = expertPerSizeOnWin_;
calcInfo_.moeExpertPerRankNum_ = moeExpertPerRankNum_;
calcInfo_.sharedExpertRankNum_ = sharedExpertRankNum_;
calcInfo_.axisH_ = axisH_;
calcInfo_.moeSendNum_ = moeSendNum_;
calcInfo_.isShardExpert_ = isShardExpert_;
calcInfo_.epSendCount_ = epSendCount;
calcInfo_.epWinContext_ = epWinContext_;
calcInfo_.winDataSizeOffset_ = winDataSizeOffset_;
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::InitStatusTargetSum()
{
// ep state
GlobalTensor<int32_t> selfStatusTensor;
selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(epStatusSpaceGm_ + SELF_STATE_OFFSET));
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
selfStatusTensor[coreIdx_ * UB_ALIGN]);
__asm__ __volatile__("");
int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN);
if (state == 0) {
sumTarget_ = static_cast<float>(1.0);
selfStatusTensor(coreIdx_ * UB_ALIGN) = 0x3F800000; // 1.0f
epStateValue_ = 0x3F800000; // 1.0f
} else {
sumTarget_ = static_cast<float>(0.0);
selfStatusTensor(coreIdx_ * UB_ALIGN) = 0;
epStateValue_ = 0;
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
selfStatusTensor[coreIdx_ * UB_ALIGN]);
__asm__ __volatile__("");
}
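// The sum target alternates between 1.0f and 0.0f on successive launches:
// peers write 0x3F800000 (1.0f) in one epoch and 0 in the next, so the
// state words never have to be cleared between invocations; the reader
// simply flips the value it waits for.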
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::BuffInit()
{
tpipe_->Reset();
tpipe_->InitBuffer(readStateBuf_, UB_ALIGN);
uint32_t sendNumAlign = Ceil(moeSendNum_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
tpipe_->InitBuffer(sendCountBuf_, sendNumAlign);
if constexpr (IsNeedReduceScatter) {
tpipe_->InitBuffer(winTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
tpipe_->InitBuffer(gmTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
if constexpr (AscendC::IsSameType<ExpandXType, bfloat16_t>::value) {
tpipe_->InitBuffer(winTpSendCountFloatBuf_, axisHFloatSize_);
tpipe_->InitBuffer(gmTpSendCountFloatBuf_, axisHFloatSize_);
winTpSendCountFloatTensor_ = winTpSendCountFloatBuf_.Get<float>();
gmTpSendCountFloatTensor_ = gmTpSendCountFloatBuf_.Get<float>();
}
} else {
tpipe_->InitBuffer(gmTpSendCountQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
}
epSendCountLocal_ = sendCountBuf_.Get<int32_t>();
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuffInit()
{
tpipe_->Reset();
uint32_t bsMulTopkSizeAligned = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
tpipe_->InitBuffer(readStateBuf_, UB_ALIGN);
tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN);
tpipe_->InitBuffer(expertIdsBuf_, bsMulTopkSizeAligned);
tpipe_->InitBuffer(expandScalesBuf_, bsMulTopkSizeAligned);
tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType));
tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_);
tpipe_->InitBuffer(mulBuf_, axisHFloatSize_);
tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_);
tpipe_->InitBuffer(indexCountsBuf_, bsMulTopkSizeAligned);
tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SplitCoreCal()
{
sendRankNum_ = epWorldSize_ / aivNum_;
uint32_t remainderRankNum = epWorldSize_ % aivNum_;
startRankId_ = sendRankNum_ * coreIdx_;
if (coreIdx_ < remainderRankNum) {
sendRankNum_++;
startRankId_ += coreIdx_;
} else {
startRankId_ += remainderRankNum;
}
endRankId_ = startRankId_ + sendRankNum_;
}
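// Worked example of the split above (illustrative numbers): with
// epWorldSize_ = 16 and aivNum_ = 5, sendRankNum_ = 3 with remainder 1,
// so core 0 handles ranks [0, 4) and cores 1..4 handle [4, 7), [7, 10),
// [10, 13), [13, 16); the remainder ranks go to the lowest core ids.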
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ReduceScatterTrans()
{
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(tpSendCountGM_[tpRankId_]);
__asm__ __volatile__("");
uint32_t offset = tpSendCountGM_.GetValue(tpRankId_) * axisH_;
GlobalTensor<ExpandXType> dataCopyInGM = expandXGM_[offset];
GM_ADDR rankGM = GetWinAddrByRankId(1 - tpRankId_, TP_DOMAIN) + tpDataOffsetOnWin_;
rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
uint32_t copyStartIdx = 0;
if (startRankId_ > 0) {
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
epSendCountGM_[epWorldSize_ + startRankId_ - 1]);
__asm__ __volatile__("");
copyStartIdx = epSendCountGM_.GetValue(epWorldSize_ + startRankId_ - 1);
}
__asm__ __volatile__("");
DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
epSendCountGM_[epWorldSize_ + endRankId_ - 1]);
__asm__ __volatile__("");
uint32_t copyEndIdx = epSendCountGM_.GetValue(epWorldSize_ + endRankId_ - 1);
LocalTensor<ExpandXType> tmpUb;
for (uint32_t tokenNumIdx = copyStartIdx; tokenNumIdx < copyEndIdx; tokenNumIdx++) {
tmpUb = moeQueue_.AllocTensor<ExpandXType>();
DataCopy(tmpUb, dataCopyInGM[tokenNumIdx * axisH_], axisH_);
moeQueue_.EnQue(tmpUb);
tmpUb = moeQueue_.DeQue<ExpandXType>();
DataCopy(rankWindow_[tokenNumIdx * axisH_], tmpUb, axisH_);
moeQueue_.FreeTensor<ExpandXType>(tmpUb);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SetWaitTpStatusAndDisPatch()
{
pipe_barrier(PIPE_ALL);
if (startRankId_ >= epWorldSize_) {
return;
}
if constexpr (IsNeedReduceScatter) {
uint32_t tpToRankId = 1 - tpRankId_;
pipe_barrier(PIPE_ALL);
LocalTensor<float> statusFlagUb = readStateBuf_.Get<float>();
statusFlagUb(0) = sumTarget_;
SyncFunc<AscendC::HardEvent::S_MTE3>();
GlobalTensor<float> tpWindowInstatusFp32Tensor_;
stateGM_ = GetWinStateAddrByRankId(tpToRankId, TP_DOMAIN) + coreIdx_ * stateOffset_;
tpWindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)stateGM_);
DataCopy<float>(tpWindowInstatusFp32Tensor_, statusFlagUb, 8UL);
SyncFunc<AscendC::HardEvent::MTE3_S>();
LocalTensor<float> statusFp32Tensor_ = readStateBuf_.Get<float>();
float sumOfFlag = static_cast<float>(-1.0);
uint32_t statusRankOffset = coreIdx_ * stateOffset_ / sizeof(float);
while (sumOfFlag != sumTarget_) {
DataCopy<float>(statusFp32Tensor_, tpStatusSpaceGlobalTensor_[statusRankOffset], 8);
SyncFunc<AscendC::HardEvent::MTE2_S>();
sumOfFlag = statusFp32Tensor_.GetValue(0);
SyncFunc<AscendC::HardEvent::S_MTE2>();
}
}
ExpertAlltoAllDispatchCopyAdd();
SyncFunc<AscendC::HardEvent::MTE3_S>();
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ExpertAlltoAllDispatchCopyAdd()
{
if (startRankId_ >= epWorldSize_) {
return;
}
uint32_t curRankExpertNum = 0;
DataCopyExtParams epSendCntParams;
if (isShardExpert_) {
curRankExpertNum = 1;
epSendCntParams = {1U, static_cast<uint32_t>(epWorldSize_ * sizeof(uint32_t)), 0U, 0U, 0U};
} else {
curRankExpertNum = moeExpertPerRankNum_;
epSendCntParams = {1U, static_cast<uint32_t>(moeSendNum_ * sizeof(uint32_t)), 0U, 0U, 0U};
}
DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
DataCopyPad(epSendCountLocal_, epSendCountGM_, epSendCntParams, copyPadParams);
SyncFunc<AscendC::HardEvent::MTE2_S>();
uint32_t preCount = 0;
uint32_t startTokenIdx = 0;
uint32_t curTokenNum = 0;
for (uint32_t expertIdx = 0U; expertIdx < curRankExpertNum; expertIdx++) {
uint32_t sendEpCount = endRankId_ - startRankId_;
for (uint32_t i = 0; i < sendEpCount; ++i) {
uint32_t ep = startRankId_ + (i + epRankId_) % sendEpCount;
if ((ep > 0) || (expertIdx > 0U)) {
preCount = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep - 1);
} else {
preCount = 0;
}
curTokenNum = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep) - preCount;
if (curTokenNum == 0) {
continue;
}
startTokenIdx = preCount * axisH_;
ExpertAlltoAllDispatchInnerCopyAdd(curTokenNum, startTokenIdx, ep, expertIdx);
}
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ExpertAlltoAllDispatchInnerCopyAdd(
uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, uint32_t ep, uint32_t expertIdx)
{
GM_ADDR rankGM = GetWinAddrByRankId(ep, EP_DOMAIN, expertIdx) + epDataOffsetOnWin_;
if ((isShardExpert_) && (ep < sharedExpertRankNum_)) {
rankGM = GetWinAddrByRankId(epRankId_, EP_DOMAIN, expertIdx) + ep * moeExpertPerRankNum_ * expertPerSizeOnWin_;
}
rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM);
uint32_t dataCnt = axisH_;
for (uint32_t loopIdx = 0; loopIdx < tokenNumLoop; loopIdx++) {
if constexpr (IsNeedReduceScatter) {
gmTpSendCountTensor_ = gmTpSendCountInQueue_.AllocTensor<ExpandXType>();
DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt);
gmTpSendCountInQueue_.EnQue(gmTpSendCountTensor_);
winTpSendCountTensor_ = winTpSendCountInQueue_.AllocTensor<ExpandXType>();
DataCopy(winTpSendCountTensor_, tpRankWindow_[srcStartTokenIdx], dataCnt);
winTpSendCountInQueue_.EnQue(winTpSendCountTensor_);
gmTpSendCountTensor_ = gmTpSendCountInQueue_.DeQue<ExpandXType>();
winTpSendCountTensor_ = winTpSendCountInQueue_.DeQue<ExpandXType>();
outTensor_ = xOutQueue_.AllocTensor<ExpandXType>();
CustomAdd(outTensor_, winTpSendCountTensor_, gmTpSendCountTensor_, dataCnt);
gmTpSendCountInQueue_.FreeTensor<ExpandXType>(gmTpSendCountTensor_);
winTpSendCountInQueue_.FreeTensor<ExpandXType>(winTpSendCountTensor_);
xOutQueue_.EnQue(outTensor_);
outTensor_ = xOutQueue_.DeQue<ExpandXType>();
DataCopy(rankWindow_[loopIdx * dataCnt], outTensor_, dataCnt);
xOutQueue_.FreeTensor<ExpandXType>(outTensor_);
} else {
gmTpSendCountTensor_ = gmTpSendCountQueue_.AllocTensor<ExpandXType>();
DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt);
ExpandXType val = expandXGM_[srcStartTokenIdx].GetValue(0);
gmTpSendCountQueue_.EnQue(gmTpSendCountTensor_);
gmTpSendCountTensor_ = gmTpSendCountQueue_.DeQue<ExpandXType>();
DataCopy(rankWindow_[loopIdx * dataCnt], gmTpSendCountTensor_, dataCnt);
gmTpSendCountQueue_.FreeTensor<ExpandXType>(gmTpSendCountTensor_);
}
srcStartTokenIdx += dataCnt;
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::CustomAdd(LocalTensor<ExpandXType> &dst,
LocalTensor<ExpandXType> &src0,
LocalTensor<ExpandXType> &src1,
uint32_t dataCnt)
{
if constexpr (AscendC::IsSameType<ExpandXType, bfloat16_t>::value) {
Cast(winTpSendCountFloatTensor_, src0, RoundMode::CAST_NONE, dataCnt);
Cast(gmTpSendCountFloatTensor_, src1, RoundMode::CAST_NONE, dataCnt);
pipe_barrier(PIPE_V);
Add(winTpSendCountFloatTensor_, winTpSendCountFloatTensor_, gmTpSendCountFloatTensor_, dataCnt);
pipe_barrier(PIPE_V);
Cast(dst, winTpSendCountFloatTensor_, RoundMode::CAST_ROUND, dataCnt);
} else {
Add(dst, src0, src1, dataCnt);
}
}
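// For bfloat16 the add is performed in fp32 (cast up, Add, then
// CAST_ROUND back), since accumulating directly in bf16 would lose
// precision; half inputs take the direct Add branch.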
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::SetStatus()
{
pipe_barrier(PIPE_ALL);
if (startRankId_ >= epWorldSize_) {
return;
}
LocalTensor<int32_t> statusFlagUb = readStateBuf_.Get<int32_t>();
statusFlagUb.SetValue(0, epStateValue_);
SyncFunc<AscendC::HardEvent::S_MTE3>();
for (uint32_t epIdx = startRankId_; epIdx < endRankId_; epIdx++) {
stateGM_ = GetWinStateAddrByRankId(epIdx, EP_DOMAIN) + epStateOffsetOnWin_;
rankStates_.SetGlobalBuffer((__gm__ int32_t *)stateGM_);
DataCopy(rankStates_, statusFlagUb, 8);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatch()
{
if (startRankId_ < epWorldSize_) {
LocalTensor<float> statusTensor = statusBuf_.Get<float>();
LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf_.Get<float>();
LocalTensor<uint32_t> gatherTmpTensor = gatherTmpBuf_.Get<uint32_t>();
LocalTensor<float> statusSumOutTensor = statusSumOutBuf_.Get<float>();
PipeBarrier<PIPE_ALL>();
gatherTmpTensor.SetValue(0, 1);
uint32_t mask = 1; // gatherMask + sum
uint64_t rsvdCnt = 0;
DataCopyParams intriParams{static_cast<uint16_t>(sendRankNum_), 1,
static_cast<uint16_t>((moeSendNum_ > 512) ? 7 : 15), 0}; // srcStride gap: stateOffset_ / 32 - 1 blocks
float sumOfFlag = static_cast<float>(-1.0);
float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5;
float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5;
SumParams sumParams{1, sendRankNum_, sendRankNum_};
SyncFunc<AscendC::HardEvent::S_V>();
while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
DataCopy<float>(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)],
intriParams);
SyncFunc<AscendC::HardEvent::MTE2_V>();
GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask,
{1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt);
PipeBarrier<PIPE_V>();
Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
SyncFunc<AscendC::HardEvent::V_S>();
sumOfFlag = statusSumOutTensor.GetValue(0);
}
}
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
} else {
SyncAll<true>();
}
}
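// The polling loop above re-reads one status flag per watched rank each
// iteration and exits once their float sum falls within +/-0.5 of
// sendRankNum_ * sumTarget_; the half-unit tolerance keeps the comparison
// robust even though the flags are accumulated in floating point.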
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
{
uint32_t beginIndex = 0;
uint32_t endIndex = 0;
uint32_t processLen = 0;
uint32_t tokenOffset = 0;
if (axisBS_ < aivNum_) {
uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_
if (coreIdx_ >= (axisBS_ * aivNumPerToken)) {
return;
}
uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
processLen = ((axisH_ / UB_ALIGN) / aivNumPerToken) * UB_ALIGN;
tokenOffset = processLen * (coreIdx_ % aivNumPerToken);
if ((coreIdx_ % aivNumPerToken) == (aivNumPerToken - 1)) {
processLen = axisH_ - ((aivNumPerToken - 1) * processLen);
}
beginIndex = tokenIndex;
endIndex = beginIndex + 1U;
} else {
uint32_t tokenPerAivNum = axisBS_ / aivNum_;
uint32_t remainderToken = axisBS_ % aivNum_;
beginIndex = tokenPerAivNum * coreIdx_;
if (coreIdx_ < remainderToken) {
tokenPerAivNum++;
beginIndex = tokenPerAivNum * coreIdx_;
} else {
beginIndex += remainderToken;
}
endIndex = beginIndex + tokenPerAivNum;
processLen = axisH_;
}
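// Two partitioning modes: when there are fewer tokens than vector cores,
// several cores share one token and each handles a UB_ALIGN-aligned slice
// of the hidden dimension (the last slice absorbs the tail); otherwise
// tokens are split across cores and each core processes whole rows of
// length axisH_.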
LocalTensor<ExpandIdxType> expertIdsLocal = expertIdsBuf_.Get<ExpandIdxType>();
LocalTensor<float> expandScalesLocal = expandScalesBuf_.Get<float>();
LocalTensor<float> rowTmpFloatLocal = rowTmpFloatBuf_.Get<float>();
LocalTensor<float> mulBufLocal = mulBuf_.Get<float>();
LocalTensor<float> sumFloatBufLocal = sumFloatBuf_.Get<float>();
LocalTensor<ExpandIdxType> indexCountsLocal = indexCountsBuf_.Get<ExpandIdxType>();
const DataCopyExtParams bskParams = {1U, static_cast<uint32_t>(bsKNum_ * sizeof(uint32_t)), 0U, 0U, 0U};
const DataCopyPadExtParams<ExpandIdxType> copyPadParams{false, 0U, 0U, 0U};
const DataCopyPadExtParams<float> copyPadFloatParams{false, 0U, 0U, 0U};
DataCopyPad(indexCountsLocal, expandIdxGM_, bskParams, copyPadParams);
DataCopyPad(expertIdsLocal, expertIdsGM_, bskParams, copyPadParams);
DataCopyPad(expandScalesLocal, expandScalesGM_, bskParams, copyPadFloatParams);
SyncFunc<AscendC::HardEvent::MTE2_S>();
for (uint32_t tokenIndex = beginIndex; tokenIndex < endIndex; tokenIndex++) {
uint32_t index = tokenIndex * axisK_;
SyncFunc<AscendC::HardEvent::MTE3_V>();
Duplicate(sumFloatBufLocal, (float)0, axisH_);
for (uint32_t i = 0; i < axisK_; i++) {
int32_t moeExpert = expertIdsLocal.GetValue(index);
if (moeExpert < 0) {
index++;
continue;
}
float scaleVal = expandScalesLocal.GetValue(index);
GM_ADDR wAddr = (__gm__ uint8_t *)(epWindowGM_) +
expertPerSizeOnWin_ * moeExpertPerRankNum_ * sharedExpertRankNum_ +
expertPerSizeOnWin_ * moeExpert + indexCountsLocal.GetValue(index) * axisHExpandXTypeSize_ +
tokenOffset * sizeof(ExpandXType);
rowTmpGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr);
ExpandXType val = rowTmpGlobal_.GetValue(0);
LocalTensor<ExpandXType> tmpUb = moeSumQueue_.AllocTensor<ExpandXType>();
DataCopy(tmpUb, rowTmpGlobal_, processLen);
moeSumQueue_.EnQue(tmpUb);
tmpUb = moeSumQueue_.DeQue<ExpandXType>();
Cast(rowTmpFloatLocal, tmpUb, AscendC::RoundMode::CAST_NONE, processLen);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Muls(mulBufLocal, rowTmpFloatLocal, scaleVal, processLen);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, mulBufLocal, processLen);
index++;
moeSumQueue_.FreeTensor<ExpandXType>(tmpUb);
}
LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
if (sharedExpertRankNum_ > 0U) {
uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
epRankId_ * axisBS_ / sharedExpertRankNum_;
__gm__ ExpandXType *shareAddr =
(__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
(tokenIndex - preCnt) * axisH_ + tokenOffset;
GlobalTensor<ExpandXType> shareTokGlobal;
shareTokGlobal.SetGlobalBuffer((__gm__ ExpandXType *)(shareAddr));
SyncFunc<AscendC::HardEvent::V_MTE2>();
DataCopy(rowTmpLocal, shareTokGlobal, processLen);
SyncFunc<AscendC::HardEvent::MTE2_V>();
Cast(rowTmpFloatLocal, rowTmpLocal, AscendC::RoundMode::CAST_NONE, processLen);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, rowTmpFloatLocal, processLen);
}
AscendC::PipeBarrier<PIPE_V>();
LocalTensor<ExpandXType> sumBufLocal = tokenBuf_.Get<ExpandXType>();
Cast(sumBufLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, processLen);
SyncFunc<AscendC::HardEvent::V_MTE3>();
DataCopy(expandOutGlobal_[tokenIndex * axisH_ + tokenOffset], sumBufLocal, processLen);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Process()
{
SyncAll<true>();
if constexpr (IsNeedReduceScatter) {
tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
ReduceScatterTrans();
}
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
BuffInit();
SetWaitTpStatusAndDisPatch();
}
AlltoAllBuffInit();
SetStatus();
WaitDispatch();
LocalWindowCopy();
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AllToAllSend()
{
if constexpr (IsNeedReduceScatter) {
tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
ReduceScatterTrans();
}
BuffInit();
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
SetWaitTpStatusAndDisPatch();
AlltoAllBuffInit();
}
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
} else {
SyncAll<true>();
}
SetStatus();
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID);
} else {
SyncAll<true>();
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::ReducePermute()
{
AlltoAllBuffInit();
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID);
} else {
SyncAll<true>();
}
WaitDispatch();
LocalWindowCopy();
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID);
}
}
} // namespace MoeDistributeCombineImpl
#endif // CAM_MOE_DISTRIBUTE_COMBINE_IMPL_H

View File

@@ -0,0 +1,18 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H
#define DISPATCH_GMM_COMBINE_DECODE_BASE_H
#include "moe_distribute_base.h"
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H

View File

@@ -0,0 +1,74 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE_TILING_H
#define DISPATCH_GMM_COMBINE_DECODE_TILING_H
#include "kernel_tiling/kernel_tiling.h"
struct DispatchGmmCombineDecodeInfo {
uint32_t epRankSize; // number of ranks in the EP communication group
uint32_t epRankId; // rank id of this device within the EP group
uint32_t moeExpertNum; // total number of MoE experts
uint32_t moeExpertNumPerRank; // number of MoE experts hosted on each rank
uint32_t sharedExpertNum; // number of shared experts
uint32_t sharedExpertRankNum; // number of ranks dedicated to shared experts
uint32_t quantMode; // quantization mode
uint32_t globalBs; // global batch size: globalBs = BS * worldSize
uint32_t bs; // per-rank batch size
uint32_t k; // top-k experts selected per token
uint32_t h; // token hidden size
uint32_t aicNum; // number of AI Core (cube) cores
uint32_t aivNum; // number of AI Vector cores
uint64_t totalUbSize; // total unified buffer (UB) size in bytes
uint64_t totalWinSize; // total HCCL window size in bytes
uint64_t gmm1HLen;
};
struct DispatchGmmCombineDecodeTilingData {
Mc2InitTiling mc2InitTiling;
Mc2CcTiling mc2CcTiling;
DispatchGmmCombineDecodeInfo disGmmDeqSwigluQuantGmmDeqComInfo;
};
constexpr uint32_t GM_ALIGN_BYTE = 512;
constexpr uint32_t CUSTOM_PRELOAD_STAGES = 1;
constexpr uint32_t CUSTOM_L1_STAGES = 2;
constexpr uint32_t CUSTOM_L0A_STAGES = 2;
constexpr uint32_t CUSTOM_L0B_STAGES = 2;
constexpr uint32_t CUSTOM_L0C_STAGES = 1;
constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true;
constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true;
constexpr uint32_t GMM1_L1M = 256;
constexpr uint32_t GMM1_L1N = 128;
constexpr uint32_t GMM1_L1K = 512;
constexpr uint32_t GMM1_L0K = 128;
constexpr uint32_t GMM1_EPIM = 64;
constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3;
constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0;
constexpr uint32_t GMM2_L1A_STAGES = 4;
constexpr uint32_t GMM2_L1B_STAGES = 2;
constexpr uint32_t GMM2_L0A_STAGES = 4;
constexpr uint32_t GMM2_L0B_STAGES = 2;
constexpr uint32_t GMM2_L1M = 128;
constexpr uint32_t GMM2_L1N = 256;
constexpr uint32_t GMM2_L1K = 512;
constexpr uint32_t GMM2_L0K = 128;
constexpr uint32_t GMM2_EPIM = 32;
constexpr uint32_t GMM2_SWIZZLE_OFFSET = 3;
constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
constexpr uint32_t WORKSPACE_STAGES = 4;
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); // deep-fuse path: cross-core event sync instead of full SyncAll
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H

View File

@@ -595,6 +595,64 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
const at::Tensor &x,
const at::Tensor &expert_ids,
const at::Tensor &gmm1_permuted_weight,
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &expert_scales,
c10::string_view group_ep,
int64_t ep_rank_size,
int64_t ep_rank_id,
int64_t moe_expert_num,
int64_t shared_expert_num,
int64_t shared_expert_rank_num,
int64_t quant_mode,
int64_t global_bs)
{
auto x_shape = x.sizes();
int bs = x_shape[0];
int h = x_shape[1];
at::Tensor output = at::empty({bs, h}, x.options());
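// A shared-expert rank hosts exactly one (shared) expert; the remaining ranks split the MoE experts
// evenly, which determines the length of the recv-count output.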
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
std::vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
group_ep_chrs.push_back('\0');
char *group_ep_ptr = group_ep_chrs.data();
EXEC_NPU_CMD(
// op api
aclnnDispatchGmmCombineDecode,
// input tensors
x,
expert_ids,
gmm1_permuted_weight,
gmm1_permuted_weight_scale,
gmm2_weight,
gmm2_weight_scale,
expert_smooth_scales,
expert_scales,
// input attrs
group_ep_ptr,
ep_rank_size,
ep_rank_id,
moe_expert_num,
shared_expert_num,
shared_expert_rank_num,
quant_mode,
global_bs,
// output tensors
output,
ep_recv_count);
return {output, ep_recv_count};
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
c10::optional<c10::string_view> format_mode,
c10::optional<c10::string_view> quant_mode)
@@ -818,6 +876,19 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)");
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
ops.def(
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
" Tensor gmm1_permuted_weight_scale,"
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
" str group_ep='',"
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
" int shared_expert_num=1, int shared_expert_rank_num=0,"
" int quant_mode=0,"
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
);
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);
ops.def(
"grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale,"
" Tensor group_list, *,"

View File

@@ -154,6 +154,37 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
const at::Tensor &x,
const at::Tensor &expert_ids,
const at::Tensor &gmm1_permuted_weight,
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &expert_scales,
c10::string_view group_ep,
int64_t ep_rank_size,
int64_t ep_rank_id,
int64_t moe_expert_num,
int64_t shared_expert_num,
int64_t shared_expert_rank_num,
int64_t quant_mode,
int64_t global_bs)
{
auto x_shape = x.sizes();
int bs = x_shape[0];
int h = x_shape[1];
at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta));
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta));
return {output, ep_recv_count};
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
c10::optional<c10::string_view> format_mode,
c10::optional<c10::string_view> quant_mode)
@@ -255,6 +286,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant);
// Grouped matmul swiglu quant weight nz tensor list
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta);
// dispatch_gmm_combine_decode meta implementation
ops.impl("dispatch_gmm_combine_decode", &vllm_ascend::meta::dispatch_gmm_combine_decode_meta);
// batch_matmul_transpose
ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose);
// Lightning indexer

View File

@@ -163,6 +163,7 @@ cd ..
```
vllm-ascend builds its custom operators by default. If you do not want to build them, set the environment variable `COMPILE_CUSTOM_KERNELS=0` to disable the build.
If you are building custom operators for Atlas A3, run `git submodule update --init --recursive` manually beforehand, or make sure your environment has Internet access so the build can fetch the required submodules; a sketch of both build paths follows this note.
:::
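For illustration, here is a minimal sketch of the build paths described above, assuming a standard source install via `pip` from a fresh checkout (exact commands and flags may differ in your environment):
```bash
# Default source build: custom operators are compiled automatically.
git clone --recurse-submodules https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
pip install -v -e .

# Opt out of building custom operators entirely.
COMPILE_CUSTOM_KERNELS=0 pip install -v -e .

# Atlas A3 build from a checkout cloned without --recurse-submodules:
git submodule update --init --recursive
pip install -v -e .
```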
```{note}

View File

@@ -0,0 +1,436 @@
import gc
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
import torchair
from vllm_ascend.utils import enable_custom_op
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
def redirect_output(log_file_path):
log_path = Path(LOG_NAME) / log_file_path
log_path.parent.mkdir(parents=True, exist_ok=True)
f = open(log_path, "w")
os.dup2(f.fileno(), sys.stdout.fileno())
os.dup2(f.fileno(), sys.stderr.fileno())
return f
def permute_weight(w: torch.Tensor, tile_n):
*dims, n = w.shape
order = list(range(len(dims))) + [-2, -3, -1]
return w.reshape(*dims, 2, n // tile_n,
tile_n // 2).permute(order).reshape(*dims,
n).contiguous()
def from_inclusive_prefix_sum(pref):
if isinstance(pref, torch.Tensor):
if pref.numel() == 0:
return pref
return torch.cat([pref[:1], pref[1:] - pref[:-1]])
if not pref:
return []
out = [pref[0]]
for i in range(1, len(pref)):
out.append(pref[i] - pref[i - 1])
return out
def output_to_file(rank_id):
return False
class DecodeMoeOps(torch.nn.Module):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__()
self.ep_hcomm_info = ep_hcomm_info
self.batch_size = batch_size
self.token_hidden_size = token_hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.ep_world_size = ep_world_size
self.moe_expert_num = moe_expert_num
self.global_rank_id = global_rank_id
self.shared_expert_rank_num = shared_expert_rank_num
is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num)
self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
self.ep_recv_count_size = self.local_expert_num * ep_world_size
self.gmm1_weight = torch.empty([
self.local_expert_num, self.token_hidden_size,
self.moe_intermediate_size * 2
])
self.gmm1_weight_scale = torch.empty(
[self.local_expert_num, self.moe_intermediate_size * 2])
self.gmm2_weight = torch.empty([
self.local_expert_num, self.moe_intermediate_size,
self.token_hidden_size
])
self.gmm2_weight_scale = torch.empty(
[self.local_expert_num, self.token_hidden_size])
self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale)
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
raise NotImplementedError("To be implemented in subclass")
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
raise NotImplementedError("To be implemented in subclass")
def forward(self, x, expert_ids, smooth_scales, expert_scales):
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
class SmallOps(DecodeMoeOps):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
self.tp_hcomm_info = ""
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x,
expert_ids=expert_ids,
expert_scales=expert_scales,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
group_tp=self.tp_hcomm_info,
tp_world_size=1,
tp_rank_id=0,
expert_shard_type=0,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=2,
global_bs=self.batch_size * self.ep_world_size,
expert_token_nums_type=1,  # 0 means prefix sums, 1 means per-expert counts
)
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
output_dtype = x.dtype
y1_int32 = torch_npu.npu_grouped_matmul(
x=[expand_x],
weight=[self.gmm1_weight],
split_item=3,
group_list_type=1,  # default 0 means group_list holds prefix sums; 1 means per-group counts
group_type=0,  # 0 means grouping along the m axis
group_list=expert_token_nums,
output_dtype=torch.int32)[0]
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
x=y1_int32,
weight_scale=self.gmm1_weight_scale.to(torch.float32),
activation_scale=dynamic_scales,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_token_nums,
activate_left=True,
quant_mode=1,
)
y2 = torch_npu.npu_grouped_matmul(x=[y1],
weight=[self.gmm2_weight],
scale=[self.gmm2_weight_scale],
per_token_scale=[y1_scale],
split_item=2,
group_list_type=1,
group_type=0,
group_list=expert_token_nums,
output_dtype=output_dtype)[0]
combine_output = torch_npu.npu_moe_distribute_combine_v2(
expand_x=y2,
expert_ids=expert_ids,
assist_info_for_combine=assist_info_for_combine,
ep_send_counts=ep_send_counts,
expert_scales=expert_scales,
group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
tp_send_counts=tp_send_counts,
expand_scales=expand_scales,
group_tp=self.tp_hcomm_info,
tp_world_size=1,
tp_rank_id=0,
expert_shard_type=0,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
global_bs=self.batch_size * self.ep_world_size)
return (combine_output, ep_send_counts[:self.ep_recv_count_size])
class FusionOp(DecodeMoeOps):
def __init__(self,
gmm1_weight,
gmm1_weight_scale,
gmm2_weight,
gmm2_weight_scale,
ep_hcomm_info,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
.view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
.transpose(1,2).contiguous()\
.view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
.transpose(1,2).contiguous()
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.ND)
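# in-place no-op add: presumably forces materialization of the ND tensor before the cast back to FRACTAL_NZ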
gmm1_weight.add_(0)
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
gmm2_weight = torch_npu.npu_format_cast(
gmm2_weight.transpose(1, 2).contiguous(),
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x,
expert_ids=expert_ids,
gmm1_permuted_weight=self.gmm1_weight,
gmm1_permuted_weight_scale=self.gmm1_weight_scale,
gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale,
expert_smooth_scales=smooth_scales,
expert_scales=expert_scales,
group_ep=self.ep_hcomm_info,
ep_rank_size=self.ep_world_size,
ep_rank_id=self.global_rank_id,
moe_expert_num=self.moe_expert_num,
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=0,
global_bs=self.batch_size * self.ep_world_size)
return output
def generate_datas(batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0,
top_k=8,
test_bfloat16=True,
enable_dynamic_bs=False):
is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num)
actual_bs = int(
torch.randint(1, batch_size, [1]).item()
if enable_dynamic_bs else batch_size)
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
gmm1_input_dim = token_hidden_size
gmm1_output_dim = moe_intermediate_size * 2
gmm2_input_dim = moe_intermediate_size
gmm2_output_dim = token_hidden_size
x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
expert_ids = torch.arange(
global_rank_id * batch_size * top_k,
global_rank_id * batch_size * top_k + actual_bs * top_k).to(
torch.int32).view(actual_bs, top_k)
expert_ids = expert_ids % moe_expert_num
if is_shared_expert:
gmm1_weight = torch.ones([
local_expert_num, gmm1_input_dim, gmm1_output_dim
]).to(torch.int8) * 4
gmm2_weight = torch.ones([
local_expert_num, gmm2_input_dim, gmm2_output_dim
]).to(torch.int8) * 4
gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
]) * 0.0015
gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
]) * 0.0015
else:
gmm1_weight = torch.randint(
-16, 16,
[local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
gmm2_weight = torch.randint(
-16, 16,
[local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
]) * 0.003 + 0.0015
gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
]) * 0.003 + 0.0015
expert_scales = torch.rand(actual_bs, top_k)
if test_bfloat16:
x = x.bfloat16()
gmm1_weight_scale = gmm1_weight_scale.bfloat16()
gmm2_weight_scale = gmm2_weight_scale.bfloat16()
else:
x = x.half()
smooth_scales = None
return (x, expert_ids, smooth_scales, expert_scales), \
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
actual_bs
def run_once(local_rank_id,
batch_size,
token_hidden_size,
moe_intermediate_size,
ep_world_size,
moe_expert_num,
shared_expert_rank_num=0,
top_k=8,
test_bfloat16=True,
enable_dynamic_bs=False,
test_graph=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id  # single-node test: local rank equals global rank
device_id = local_rank_id % 16
torch_npu.npu.set_device(device_id)
# initialize the distributed environment
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"  # any free port works
dist.init_process_group(backend="hccl",
rank=local_rank_id,
world_size=ep_world_size)
ep_ranks_list = list(np.arange(0, ep_world_size))
ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list)
ep_hcomm_info_fused = ep_group._get_backend(
torch.device("npu")).get_hccl_comm_name(local_rank_id)
ep_hcomm_info_small = ep_group_small._get_backend(
torch.device("npu")).get_hccl_comm_name(local_rank_id)
torch_npu.npu.synchronize(device_id)
parameter = (batch_size, token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
input_datas, weight_datas, actual_bs = generate_datas(
*parameter, top_k, test_bfloat16, enable_dynamic_bs)
input_datas = [
data.npu() if data is not None else None for data in input_datas
]
weight_datas = [
data.npu() if data is not None else None for data in weight_datas
]
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
*parameter).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore
if test_graph:
fused_ops = torch.compile(fused_ops, backend=npu_backend)
small_op_token_output, small_op_count_output = small_ops(*input_datas)
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
torch_npu.npu.synchronize(device_id)
dist.destroy_process_group()
if log_file is not None:
log_file.close()
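# the small-op path reports send counts as an inclusive prefix sum; convert to per-entry counts before comparing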
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
torch.testing.assert_close(small_op_token_output.cpu(),
fused_op_token_output.cpu(),
atol=2.0,
rtol=0.02)
torch.testing.assert_close(small_op_count_output.cpu(),
fused_op_count_output.cpu())
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@torch.inference_mode()
def test():
batch_size = 64
token_hidden_size = 7168
moe_intermediate_size = 2048
ep_world_size = 16
moe_expert_num = 64
shared_expert_rank_num = 0
top_k = 8
test_bfloat16 = True
enable_dynamic_bs = False
test_graph = False
args = (batch_size, token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
test_bfloat16, enable_dynamic_bs, test_graph)
mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)