add dispatch_ffn_combine kernel (#3532)

### What this PR does / why we need it?

This PR introduces the Ascend implementation of the
`dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime,
together with follow‑up fixes to ensure the kernel builds and runs
correctly in CI.

- Add full host and device implementation of the `dispatch_ffn_combine`
kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE
routing helpers, and kernel utilities for quantized FFN dispatch.
- Integrate the new kernel with the PyTorch binding
(csrc/torch_binding.cpp, csrc/torch_binding_meta.cpp) and the Ascend
runtime (vllm_ascend/ascend_forward_context.py,
vllm_ascend/worker/model_runner_v1.py).
- Extend fused MoE communication and token dispatch support in
`vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new
dispatch path.
- Update quantization logic in vllm_ascend/quantization/w8a8_dynamic.py
to support the new FFN dispatch flow.
- Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake
configuration, and include/namespace usage in the new kernel files.
- Add an end‑to‑end nightly test
`tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper
utilities in `vllm_ascend/utils.py` to validate the new kernel.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

---------

Signed-off-by: mojave2 <chenchen145@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Chen Chen
2025-12-04 23:00:59 +08:00
committed by GitHub
parent 752a55473c
commit ad0607f900
61 changed files with 9795 additions and 53 deletions

4
.gitmodules vendored Normal file
View File

@@ -0,0 +1,4 @@
[submodule "csrc/third_party/catlass"]
path = csrc/third_party/catlass
url = https://gitcode.com/cann/catlass.git
branch = catlass-v1-stable

View File

@@ -3,6 +3,8 @@
ROOT_DIR=$1
SOC_VERSION=$2
git config --global --add safe.directory "$ROOT_DIR"
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
# ASCEND310P series
# currently, no custom aclnn ops for ASCEND310 series
@@ -11,11 +13,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
exit 0
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
# ASCEND910B (A2) series
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
SOC_ARG="ascend910_93"
else
# others
@@ -23,6 +25,30 @@ else
exit 0
fi
git submodule init
git submodule update
# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h
file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1)
if [ -z "$file_path" ]; then
echo "cannot find moe_distribute_base.h file in CANN env"
exit 1
fi
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
TARGET_FILE="$TARGET_DIR/$(basename "$file_path")"
echo "*************************************"
echo $file_path
echo "$TARGET_DIR"
cp "$file_path" "$TARGET_DIR"
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
# build custom ops
cd csrc
rm -rf build output

View File

@@ -282,7 +282,7 @@ function(add_ops_src_copy)
set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
add_custom_command(OUTPUT ${_BUILD_FLAG}
COMMAND mkdir -p ${SRC_COPY_DST}
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST}
COMMAND touch ${_BUILD_FLAG}
)

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
# Extra include options for building the DispatchFFNCombine AscendC kernel.
# The CANN toolkit ships its ascendc headers under an arch-specific prefix;
# probe the known layouts and take the first one that exists (same behavior as
# the original if/elseif chain).
set(_DISPATCH_FFN_INC_OPTS)
foreach(_dispatch_ffn_arch
        aarch64-linux
        arm64-linux
        ${CMAKE_SYSTEM_PROCESSOR}-linux)
    set(_dispatch_ffn_inc "${ASCEND_CANN_PACKAGE_PATH}/${_dispatch_ffn_arch}/ascendc/include")
    if (EXISTS ${_dispatch_ffn_inc})
        list(APPEND _DISPATCH_FFN_INC_OPTS -I${_dispatch_ffn_inc})
        break()
    endif()
endforeach()
# Headers of the catlass submodule (see .gitmodules), if checked out.
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
    list(APPEND _DISPATCH_FFN_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
endif()
# Kernel compile options; -DHCCL_COMM enables the HCCL communication path.
add_ops_compile_options(
    OP_NAME DispatchFFNCombine
    OPTIONS --cce-auto-sync=on
            -Wno-deprecated-declarations
            -Werror
            -DHCCL_COMM
            ${_DISPATCH_FFN_INC_OPTS}
)
# Host-side op definition (dtype/format registration).
target_sources(op_host_aclnnInner PRIVATE
    dispatch_ffn_combine_def.cpp
)
# Public aclnn two-stage API wrapper.
target_sources(opapi PRIVATE
    aclnn_dispatch_ffn_combine.cpp
)
if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        aclnn_dispatch_ffn_combine.cpp
    )
    target_sources(aclnn_ops_infer PRIVATE
        aclnn_dispatch_ffn_combine.cpp
    )
endif ()
# Tiling needs kernel-side headers (shared tiling structs live in op_kernel).
target_sources(optiling PRIVATE
    dispatch_ffn_combine_tiling.cpp
)
target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel
)
target_sources(opsproto PRIVATE
    dispatch_ffn_combine_proto.cpp
)
# Install the public aclnn header. The path is known exactly, so install it
# directly instead of globbing for it (was: file(GLOB _GMM_Aclnn_header ...),
# a copy-paste leftover from the GMM op); OPTIONAL keeps the previous
# tolerate-missing-file behavior.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_ffn_combine.h
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)

View File

@@ -0,0 +1,84 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_dispatch_ffn_combine.h"
#include <algorithm>
// #include "aclnn_kernels/common/op_error_check.h"
// #include "opdev/op_log.h"
// #include "opdev/common_types.h"
// #include "opdev/platform.h"
// #include "ophost/matmul_util.h"
#include <unistd.h>
#include <vector>
#include <string>
#include <iostream>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/file.h>
#include <climits>
#include "../op_host/error_log.h"
// using namespace op;
// using namespace op;
#ifdef __cplusplus
extern "C" {
#endif
// NOTE(review): these limits are not referenced anywhere in this file --
// confirm whether parameter validation using them was intended.
static constexpr size_t TWO_DIMS = 2;
static constexpr int64_t KVALUE_MIN = 256;
static constexpr int64_t KVALUE_MAX = 65535;
static constexpr size_t HCCL_GROUP_NAME_MAX = 128U;
// Selector for the engine that services HCCL commands issued by the op
// (AICPU vs MTE); this wrapper selects MTE below when the symbol is present.
enum NnopbaseHcclServerType {
    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
    NNOPBASE_HCCL_SERVER_TYPE_MTE,
    NNOPBASE_HCCL_SERVER_TYPE_END
};
// Inner (generated) two-stage entry points that the public wrappers below
// forward to; defined in the generated aclnnInner library.
extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
    const aclTensor* probs,
    const char* group, int64_t maxOutputSize,
    bool transB, bool weightNz,
    const aclTensor* out,
    uint64_t* workspaceSize, aclOpExecutor** executor);
extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize,
    aclOpExecutor *executor, aclrtStream stream);
// Weak symbol: only present when the nnopbase runtime provides HCCL server
// selection; callers must null-check before invoking.
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
// Public first-stage entry: forwards to the inner implementation with the
// fixed weight-layout configuration used by this build.
aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
    const aclTensor* probs,
    const char* group, int64_t maxOutputSize,
    const aclTensor* out,
    uint64_t* workspaceSize, aclOpExecutor** executor)
{
    constexpr bool kTransB = false;   // weights are not pre-transposed
    constexpr bool kWeightNz = true;  // weights are in NZ layout
    return aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
                                                        maxOutputSize, kTransB, kWeightNz,
                                                        out, workspaceSize, executor);
}
// Public second-stage entry: routes HCCL commands through the MTE server when
// the weakly-linked selector symbol is available, then launches the op.
aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
    if (NnopbaseSetHcclServerType != nullptr) {
        NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerDispatchFFNCombine(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,61 @@
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef OP_API_INC_DISPATCH_FFN_COMBINE_
#define OP_API_INC_DISPATCH_FFN_COMBINE_
#include <string>
#include "aclnn/aclnn_base.h"
#include "hccl/hccl.h"
#include "hccl/hccl_types.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
 * Op function: fuses the whole distributed MoE pipeline, from InitRouting
 * through Unpermute, into a single operator.
 * @brief First-stage interface of aclnnDispatchFFNCombine: computes the
 *        device workspace size required for the given compute flow.
 * @domain aclnn_ops_infer
 * @param [in] x: activation input routed to the experts.
 * @param [in] weight1: weight of the first grouped matmul of the expert FFN.
 * @param [in] weight2: weight of the second grouped matmul of the expert FFN.
 * @param [in] expertId: per-token expert indices produced by routing.
 * @param [in] scale1: quantization scale for the first matmul.
 * @param [in] scale2: quantization scale for the second matmul.
 * @param [in] probs: routing probabilities used when combining expert outputs.
 * @param [in] group: string identifying the HCCL communication group.
 * @param [in] maxOutputSize: upper bound on the dispatched output size.
 *        NOTE(review): exact dtype/shape contracts are defined by the op
 *        registration (dispatch_ffn_combine_def.cpp) -- confirm there.
 * @param [out] out: combined result tensor.
 * @param [out] workspaceSize: returns the workspace size to allocate on the
 *        NPU device side.
 * @param [out] executor: returns the op executor holding the compute flow.
 * @return aclnnStatus: status code.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
    const aclTensor* probs,
    const char* group, int64_t maxOutputSize,
    const aclTensor* out,
    uint64_t* workspaceSize, aclOpExecutor** executor);
/**
 * @brief Second-stage interface of aclnnDispatchFFNCombine: launches the
 *        computation.
 * @param [in] workspace: start address of the device-side workspace.
 * @param [in] workspaceSize: workspace size obtained from the first-stage
 *        interface aclnnDispatchFFNCombineGetWorkspaceSize.
 * @param [in] executor: op executor holding the compute flow.
 * @param [in] stream: ACL stream to launch on.
 * @return aclnnStatus: status code.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
    aclrtStream stream);
#ifdef __cplusplus
}
#endif
#endif  // OP_API_INC_DISPATCH_FFN_COMBINE_

View File

@@ -0,0 +1,88 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dispatch_ffn_combine_def.cpp
* \brief
*/
#include "register/op_def_registry.h"
namespace ops {
// Op registration for DispatchFFNCombine.
// Each DataType/Format list below has three columns, one per registered
// dtype/format combination of the op (fp16 / bf16 / bf16-with-NZ-weights).
class DispatchFFNCombine : public OpDef {
public:
    explicit DispatchFFNCombine(const char *name) : OpDef(name) {
        // Activation input.
        this->Input("a")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // First FFN weight: int8-quantized; FRACTAL_NZ layout in the third combo.
        this->Input("w1")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
            .IgnoreContiguous();
        // Second FFN weight: same dtype/format scheme as w1.
        this->Input("w2")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
            .IgnoreContiguous();
        // Per-token expert indices from routing.
        this->Input("expertIdx")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Quantization scale for the first matmul (packed as int64).
        this->Input("scale1")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Quantization scale for the second matmul (packed as int64).
        this->Input("scale2")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Routing probabilities used by the combine step.
        this->Input("probs")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Output: same dtype as the activation input "a".
        this->Output("out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
        // Attributes consumed by tiling (see dispatch_ffn_combine_tiling.cpp).
        this->Attr("group").AttrType(REQUIRED).String();
        this->Attr("M").AttrType(OPTIONAL).Int();
        this->Attr("transB").AttrType(OPTIONAL).Bool(false);
        this->Attr("weightNz").AttrType(OPTIONAL).Bool(false);
        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
            .ExtendCfgInfo("jitCompile.flag", "static_false")
            .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
        // Supported SOCs: A3 (ascend910_93) and A2 (ascend910b).
        this->AICore().AddConfig("ascend910_93", aicore_config);
        this->AICore().AddConfig("ascend910b", aicore_config);
        // MC2 (compute-communication fusion): bind the HCCL group attribute.
        this->MC2().HcclGroup("group");
    }
};
OP_ADD(DispatchFFNCombine);
} // namespace ops

View File

@@ -0,0 +1,40 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dispatch_ffn_proto.cpp
* \brief
*/
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>
// #include "../../common/ophost/op_util.h"
// #include "../../common/ophost/hcom_topo_info.h"
// #include "log/ops_log.h"
using namespace ge;
namespace ops {
// Attribute-index / dim constants.
// NOTE(review): currently unreferenced in this translation unit -- confirm
// whether they are still needed before removing.
const size_t ATTR_GROUP = 0;
const size_t ATTR_RANK_SIZE = 1;
const size_t SUPPORT_DIM_SIZE = 2;
// Shape inference stub: the output shape is intentionally not written here.
// NOTE(review): confirm that shape resolution is handled by the meta binding /
// runtime and that leaving the context untouched is intended.
static ge::graphStatus InferShapeDispatchFFNCombine(gert::InferShapeContext* context) {
    return ge::GRAPH_SUCCESS;
}
// Dtype inference stub; the commented lines show the intended propagation
// (output dtype = dtype of input 0). NOTE(review): confirm whether enabling
// it is required for graph mode.
static ge::graphStatus InferDataTypeDispatchFFNCombine(gert::InferDataTypeContext* context) {
    // auto d_type = context->GetInputDataType(0);
    // context->SetOutputDataType(0, d_type);
    return ge::GRAPH_SUCCESS;
}
// Register both inference callbacks for the op.
IMPL_OP_INFERSHAPE(DispatchFFNCombine)
    .InferShape(InferShapeDispatchFFNCombine)
    .InferDataType(InferDataTypeDispatchFFNCombine);
} // namespace ops

View File

@@ -0,0 +1,265 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dispatch_ffn_tiling.cpp
* \brief
*/
#include "vector"
#include "register/tilingdata_base.h"
#include "tiling/tiling_api.h"
#include "error_log.h"
#include "hcom_topo_info.h"
#include "register/op_def_registry.h"
#include "dispatch_ffn_combine_tiling.h"
#include <vector>
#include <map>
#include <algorithm>
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
using namespace AscendC;
using namespace ge;
namespace {
// 1. 常量定义
const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
constexpr uint32_t ATTR_GROUP_INDEX = 0;
constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1;
constexpr uint32_t ATTR_IS_TRANS_B = 2;
constexpr uint32_t ATTR_WEIGHT_NZ = 3;
constexpr uint64_t INIT_TILINGKEY = 1000000;
constexpr uint64_t TILINGKEY_TRANS_B = 1U;
constexpr uint64_t TILINGKEY_WEIGHT_NZ = 10;
constexpr uint32_t X_INDEX = 0;
constexpr uint32_t WEIGHT_INDEX = 1;
constexpr uint32_t WEIGHT2_INDEX = 2;
constexpr uint32_t EXPERTID_INDEX = 3;
constexpr uint32_t BLOCK_NUM = 20;
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
}
namespace optiling {
// Ceiling division. Returns 0 when the divisor is 0 instead of trapping on
// a divide-by-zero.
static int32_t CeilDev(int32_t num, int32_t div)
{
    return (div == 0) ? 0 : (num + div - 1) / div;
}
// Parse and validate the op attributes (group, maxOutputSize, transB,
// weightNz) and record them, plus the group's rank size, in the tiling info.
static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
{
    auto attrs = context->GetAttrs();
    OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);
    auto groupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
    auto maxOutputSizePtr = attrs->GetAttrPointer<int>(ATTR_MAX_OUTPUT_SIZE_INDEX);
    auto is_trans_b = attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B);
    auto weight_nz = attrs->GetAttrPointer<bool>(ATTR_WEIGHT_NZ);
    OP_TILING_CHECK(groupPtr == nullptr || strlen(groupPtr) == 0,
                    OP_LOGE(K_INNER_DEBUG, "group is invalid."), return GRAPH_FAILED);
    // Fix: maxOutputSizePtr was dereferenced below without a null check,
    // unlike every other attribute pointer.
    OP_TILING_CHECK(maxOutputSizePtr == nullptr,
                    OP_LOGE(K_INNER_DEBUG, "maxOutputSize is invalid."), return GRAPH_FAILED);
    OP_TILING_CHECK(is_trans_b == nullptr,
                    OP_LOGE(K_INNER_DEBUG, "is_trans_b is invalid."), return GRAPH_FAILED);
    OP_TILING_CHECK(weight_nz == nullptr,
                    OP_LOGE(K_INNER_DEBUG, "weight_nz is invalid."), return GRAPH_FAILED);
    info.maxOutputSize = *maxOutputSizePtr;
    info.isTransposeB = *is_trans_b;
    info.isWeightNz = *weight_nz;
    // Fix: initialize rankSize so a failed lookup cannot leave it indeterminate.
    // NOTE(review): the return status of GetGroupRankSize is still ignored --
    // confirm upstream guarantees the group has been registered.
    int64_t rankSize = 0;
    (void)ge::HcomTopoInfo::Instance().GetGroupRankSize(groupPtr, rankSize);
    info.worldSize = rankSize;
    OP_LOGD(K_INNER_DEBUG, "maxOutputSize=%d ", info.maxOutputSize);
    OP_LOGD(K_INNER_DEBUG, "rankSize=%d ", info.worldSize);
    return ge::GRAPH_SUCCESS;
}
// Extract the shapes of x, w1 and expertIdx and derive M, K, N,
// experts-per-rank and topK for the tiling info.
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
{
    const char *nodeName = context->GetNodeName();
    const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
    const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX);
    const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX);
    // Fix: the shape pointers were dereferenced without null checks.
    OP_TILING_CHECK(aStorageShape == nullptr || bStorageShape == nullptr || expertIdxShape == nullptr,
                    OP_LOGE(nodeName, "input shape is nullptr."), return ge::GRAPH_FAILED);
    // x: [M, K]
    uint32_t M = aStorageShape->GetStorageShape().GetDim(0);
    uint32_t K = aStorageShape->GetStorageShape().GetDim(1);
    // w1: dim 0 = experts per rank, dim 2 = N.
    // NOTE(review): assumes the non-transposed [E, K, N] layout -- confirm for
    // the transB case.
    uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0);
    uint32_t N = bStorageShape->GetStorageShape().GetDim(2);
    // expertIdx: [M, topK]
    uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1);
    info.M = M;
    info.N = N;
    info.K = K;
    info.expertPerRank = expertPerRank;
    info.topK = topK;
    OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
    OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
    OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
    OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
    OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
    return ge::GRAPH_SUCCESS;
}
// Query the current platform for the AI (vector) core count and UB capacity
// and record them in the tiling info.
static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
{
    auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    uint64_t ubBytes = 0U;
    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    info.aivNum = platform.GetCoreNumAiv();
    info.totalUbSize = ubBytes;
    OP_LOGD(K_INNER_DEBUG, "aivNum=%d", info.aivNum);
    OP_LOGD(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize);
    return ge::GRAPH_SUCCESS;
}
// Fill the CoC (compute-on-communication) tiling with the fixed block shape
// used by this kernel; only the NPU-side communication split depends on the
// world size.
void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineInfo &info)
{
    constexpr uint32_t kBaseM = 128;
    constexpr uint32_t kBaseK = 256;
    constexpr uint32_t kBaseN = 256;
    cocTilingData.m0 = kBaseM;
    cocTilingData.k0 = kBaseK;
    cocTilingData.n0 = kBaseN;
    // Swizzle pattern of the matmul block walk.
    cocTilingData.swizzleDirect = 1;
    cocTilingData.swizzleOffset = 7;
    cocTilingData.ubMoveNum = 16 * 1024;
    cocTilingData.pValue = 1;
    // Split communication across all ranks; no further data split.
    cocTilingData.commNpuSplit = info.worldSize;
    cocTilingData.commDataSplit = 1;
    // Half of one output block per communication loop.
    cocTilingData.lenPerLoop = kBaseM * kBaseN / 2;
}
// Main scheduling function:
// get tilingData -> check attrs -> check shapes -> query platform info
// -> SetTilingData (per rank count) -> set blockDim -> set tilingKey
// -> set workspace -> configure communication
static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context)
{
    const char *nodeName = context->GetNodeName();
    OP_LOGI(nodeName, "Enter DispatchFFNCombine tiling func.");
    // 1. tilingData
    DispatchFFNCombineTilingData *tilingData = context->GetTilingData<DispatchFFNCombineTilingData>();
    OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."),
                    return ge::GRAPH_FAILED);
    OP_LOGI(nodeName, "DispatchFFNCombine get tilingData.");
    DispatchFFNCombineInfo& info = tilingData->dispatchFFNCombineInfo;
    OP_LOGI(nodeName, "DispatchFFNCombine get tilingData info.");
    OP_TILING_CHECK(DispatchFFNCombineCheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
                    OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckAttrAndSetTiling Failed"),
                    return ge::GRAPH_FAILED);
    OP_TILING_CHECK(DispatchFFNCombineCheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
                    OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckShapeAndSetTiling Failed"),
                    return ge::GRAPH_FAILED);
    OP_TILING_CHECK(DispatchFFNCombineGetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS,
                    OP_LOGE(context->GetNodeName(), "DispatchFFNCombine GetPlatformInfoAndSetTiling Failed"),
                    return ge::GRAPH_FAILED);
    SetTilingData(tilingData->cocTiling, info);
    // 2. set blockDim
    uint32_t blockDim = 1U;
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    auto aicNum = ascendcPlatform.GetCoreNumAic();
    auto aivNum = ascendcPlatform.GetCoreNumAiv();
    blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum);
    context->SetBlockDim(blockDim);
    // 3. set tilingKey
    // Base key 1000000; +1 selects the transB variant, +10 the NZ-weight
    // variant (matches TILING_KEY_IS(1000000/1000001/1000010) on the device).
    uint64_t tilingKey = INIT_TILINGKEY;
    tilingKey += info.isTransposeB ? TILINGKEY_TRANS_B : 0;
    tilingKey += info.isWeightNz ? TILINGKEY_WEIGHT_NZ : 0;
    context->SetTilingKey(tilingKey);
    // NOTE(review): "%d" with a uint64_t is a format mismatch; benign only
    // because OP_LOGD expands to nothing in this build (see error_log.h).
    OP_LOGD(K_INNER_DEBUG, "tilingKey=%d", tilingKey);
    // Tiling for the fused MoeInitRoutingQuantV2 (routing + quantization) stage.
    optiling::MoeInitRoutingQuantV2TilingBase moeInitRoutingQuantV2TilingBase;
    int64_t inuptXDtypeSize = sizeof(int16_t);  // fp16/bf16 element size
    int64_t scaleDim0 = 0;
    // NOTE(review): hard-coded UB budget for the routing tiling -- confirm it
    // is intentional rather than info.totalUbSize.
    int64_t ubSize = 196352;
    int64_t expertCapacity = 0;
    int64_t expertNum = info.expertPerRank * info.worldSize;
    int64_t activeNum = 0;
    int64_t dropPadMode = 0;
    int64_t expertTokensCountOrCumsumFlag = 2;
    bool expertTokensBeforeCapacityFlag = false;
    int64_t quantMode = 1;
    uint32_t aivNumInitRouting = 2 * BLOCK_NUM;
    moeInitRoutingQuantV2TilingBase.DoTiling(info.M, info.K, info.topK, expertCapacity, expertNum, activeNum, dropPadMode,
        expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0, aivNumInitRouting, ubSize);
    uint64_t initRoutingQuantTilingKey = moeInitRoutingQuantV2TilingBase.tilingKey_;
    size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_;
    // Copy the routing tiling into the kernel tiling struct.
    // NOTE(review): the per-field assignments below duplicate members already
    // covered by the whole-struct assignment on the first line -- confirm they
    // are redundant before simplifying.
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData = moeInitRoutingQuantV2TilingBase.quantTilingData;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vbsComputeParamsOp;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vmsMiddleComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vmsMiddleComputeParamsOp;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.sortOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.sortOutComputeParamsOp;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstComputeParamsOp;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstCapacityComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstCapacityComputeParamsOp;
    tilingData->cocTiling.moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.gatherOutComputeParamsOp;
    tilingData->cocTiling.initRoutingQuantTilingKey = initRoutingQuantTilingKey;
    // 4. workspace
    size_t *workSpaces = context->GetWorkspaceSizes(1);
    OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."),
                    return ge::GRAPH_FAILED);
    // Second-matmul dims derived by swapping: n2 = K, k2 = N / 2.
    // NOTE(review): inferred from the swap itself; confirm against the kernel.
    uint32_t n2 = info.K;
    uint32_t k2 = info.N / 2;
    // Scratch size: the exact meaning of each term is defined by the device
    // kernel; the final size is the larger of the CoC and routing needs plus
    // the fixed system reservation.
    uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) +
        info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 +
        info.maxOutputSize * sizeof(float) * 2 +
        std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
        std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
    workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace);
    // 5. communication
    auto attrs = context->GetAttrs();
    auto group = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
    uint32_t opType = 8U;
    std::string algConfig = "AlltoAll=level0:fullmesh;level1:pairwise";
    AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig);
    mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling);
    mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling);
    OP_LOGI(nodeName, "Leave DispatchFFNCombine tiling func.");
    return ge::GRAPH_SUCCESS;
}
// Registered tiling entry point; thin wrapper over the implementation above.
static ge::graphStatus DispatchFFNCombineTilingFunc(gert::TilingContext* context)
{
    return DispatchFFNCombineTilingFuncImpl(context);
}
// No compile-time information is needed for this op.
struct DispatchFFNCombineCompileInfo {};
// Tiling-parse hook; nothing to parse for this op.
ge::graphStatus TilingParseForDispatchFFNCombine(gert::TilingParseContext *context)
{
    (void)context;
    return ge::GRAPH_SUCCESS;
}
// Register the tiling callbacks for DispatchFFNCombine.
IMPL_OP_OPTILING(DispatchFFNCombine)
    .Tiling(DispatchFFNCombineTilingFunc)
    .TilingParse<DispatchFFNCombineCompileInfo>(TilingParseForDispatchFFNCombine);
} // namespace optiling

View File

@@ -0,0 +1,47 @@
// Lightweight logging / check macros for the tiling build.
// Info and debug logging are compiled out; warnings and errors print to stdout.
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
#include <string>
#include "toolchain/slog.h"
// Info logging: intentionally a no-op in this build.
#define OP_LOGI(opname, ...)
// Warning logging: printf with an opname prefix and trailing newline.
#define OP_LOGW(opname, ...) \
    do { \
        printf("[WARN][%s] ", (opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
// Error logging without triggering error reporting.
// NOTE(review): the "[ERRORx]" prefix looks like a leftover marker -- confirm
// whether "[ERROR]" was intended.
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
    do { \
        printf("[ERRORx][%s] ", (opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
// Error logging.
#define OP_LOGE(opname, ...) \
    do { \
        printf("[ERROR][%s] ", (opname)); \
        printf(__VA_ARGS__); \
        printf("\n"); \
    } while (0)
// Debug logging: intentionally a no-op in this build.
#define OP_LOGD(opname, ...)
namespace optiling {
// NOTE(review): macros are not scoped by namespaces; this namespace only
// groups them visually with the rest of the tiling code.
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
    do { \
        OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
    } while (0)
// Evaluate `cond`; on failure run `log_func` and then `expr`
// (conventionally a `return` statement).
#define OP_TILING_CHECK(cond, log_func, expr) \
    do { \
        if (cond) { \
            log_func; \
            expr; \
        } \
    } while (0)
} // namespace optiling
#endif  // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_

View File

@@ -0,0 +1,72 @@
/* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* ===================================================================================================================*/
#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_
#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_
#include <unordered_map>
#include <mutex>
using Status = int32_t;
namespace ge {
// Communication topology kinds encoded as a bit set.
static constexpr uint32_t COMM_MESH = 0b1U;
static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U);
static constexpr uint32_t COMM_RING = (COMM_MESH << 2U);
static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U);
// Singleton registry of per-communication-group topology information (rank
// size, notify handles, ordered streams). Declaration-only copy used so the
// tiling code can link against the metadef implementation.
// NOTE(review): must stay in sync with the metadef header of the same name.
class HcomTopoInfo {
public:
    // Topology hierarchy levels.
    enum class TopoLevel {
        L0 = 0,
        L1,
        MAX,
    };
    // Topology description of one level: supported comm patterns + rank count.
    struct TopoLevelDesc {
        uint32_t comm_sets;
        uint32_t rank_size;
    };
    using TopoDescs = TopoLevelDesc[static_cast<int32_t>(TopoLevel::MAX)];
    // Per-group topology record.
    struct TopoInfo {
        int64_t rank_size;
        void *notify_handle;
        TopoDescs topo_level_descs;
    };
    static HcomTopoInfo &Instance();
    bool TopoInfoHasBeenSet(const char_t *group);
    bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info);
    Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info);
    Status GetGroupRankSize(const char_t *group, int64_t &rank_size);
    TopoDescs *GetGroupTopoDesc(const char_t *group);
    Status GetGroupNotifyHandle(const char_t *group, void *&notify_handle);
    // Remove a group's topology record (thread-safe).
    void UnsetGroupTopoInfo(const char_t *group) {
        const std::lock_guard<std::mutex> lock(mutex_);
        (void) rank_info_.erase(group);
    }
    Status SetGroupOrderedStream(const char_t *group, void *stream);
    Status GetGroupOrderedStream(const char_t *group, void *&stream);
    // Remove a group's ordered stream (thread-safe).
    void UnsetGroupOrderedStream(const char_t *group) {
        const std::lock_guard<std::mutex> lock(mutex_);
        (void) group_to_ordered_stream_.erase(group);
    };
    Status SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream);
    Status GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream);
    void UnsetGroupOrderedStream(const int32_t device_id, const char_t *group);
private:
    HcomTopoInfo() = default;
    ~HcomTopoInfo() = default;
    std::unordered_map<std::string, TopoInfo> rank_info_;
    std::mutex mutex_;
    // Per-group ordered streams.
    std::unordered_map<std::string, void*> group_to_ordered_stream_;
    // Per-device, per-group ordered streams.
    std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_;
};
}
#endif  // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_

View File

@@ -0,0 +1,9 @@
#ifndef TILING_ARGS_H
#define TILING_ARGS_H
#include <cstdint>
namespace Moe {
// Byte offsets into the communication window.
// NOTE(review): these must stay in sync with the device-side kernel layout --
// confirm before changing either value.
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL;      // 3 MiB
constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL;  // 204 MiB
} // namespace Moe
#endif  // TILING_ARGS_H

View File

@@ -0,0 +1,51 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file dispatch_ffn_combine.cpp
* \brief
*/
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "dispatch_ffn_combine_tiling.h"
#include "dispatch_ffn_combine.h"
using namespace AscendC;
using namespace DispatchFFNCombineImpl;
// Kernel entry point. The tiling key selects the DispatchFFNCombine template
// specialization; the two trailing template booleans are TB_ (weights
// transposed) and Nz_ (weights in NZ layout) — see TemplateMMA2AClass in
// dispatch_ffn_combine.h.
// NOTE(review): keys 1000000 and 1000010 instantiate the identical
// specialization <..., false, true>; if the last two key digits are meant to
// encode (TB, Nz), 1000000 would be expected to be <..., false, false> —
// confirm against the host-side tiling-key computation.
extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
if (TILING_KEY_IS(1000000)) {
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000001)) {
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000010)) {
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000011)) {
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
op.Process();
}
}

View File

@@ -0,0 +1,276 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dispatch_ffn_combine.h
* \brief
*/
#ifndef DISPATCH_FFN_COMBINE_H
#define DISPATCH_FFN_COMBINE_H
using namespace AscendC;
#include "kernel_operator.h"
#include "utils/moe_distribute_base.h"
#include "dispatch_ffn_combine_tiling.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/epilogue/block/block_epilogue.hpp"
#include "catlass/epilogue/tile/tile_copy.hpp"
#include "catlass/epilogue/tile/tile_elemwise_add.hpp"
#include "catlass/epilogue/tile/tile_elemwise_muls.hpp"
#include "catlass/gemm/block/block_mmad.hpp"
#include "catlass/gemm/block/block_swizzle.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/kernel/matmul_epilogue.hpp"
#include "catlass/gemm/gemm_type.hpp"
#include "catlass/layout/layout.hpp"
#include "utils/select_helper.hpp"
#include "utils/const_args.hpp"
#include "dispatch_ffn_combine_kernel.hpp"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
using namespace Catlass;
namespace DispatchFFNCombineImpl {
#define TemplateMMA2AClass typename AType_, typename BType_, typename CType_, bool TB_, bool Nz_
#define TemplateMMA2ACFunc AType_, BType_, CType_, TB_, Nz_
using namespace AscendC;
// Top-level device operator that fuses: token dispatch (MOE routing +
// dynamic quantization) -> grouped FFN matmuls (GMM1/GMM2 with swiglu in
// between) -> combine (weighted un-permute of expert outputs).
// Template parameters (via TemplateMMA2AClass): AType_/BType_/CType_ element
// types; TB_ = weights transposed; Nz_ = weights stored in NZ layout.
template <TemplateMMA2AClass>
class DispatchFFNCombine {
public:
__aicore__ inline DispatchFFNCombine() {};
// Caches kernel arguments and tiling fields; resolves rank/world size from
// the HCCL context. Must be called before Process().
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
// Builds the Catlass kernel types from the cached fields and runs the
// fused pipeline.
__aicore__ inline void Process();
private:
// Raw GM pointers captured in Init().
GM_ADDR xGM_;
GM_ADDR weight1GM_;
GM_ADDR weight2GM_;
GM_ADDR expertIdGM_;
GM_ADDR scale1GM_;
GM_ADDR scale2GM_;
GM_ADDR probs_;
GM_ADDR outGM_;
GM_ADDR workspaceGM_;
// Optional routing-quant inputs; nullptr means "not provided".
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
GM_ADDR moeInitRoutingQuantV2Offset = nullptr;
GM_ADDR expertTokensBeforeCapacity = nullptr;
TBuf<AscendC::TPosition::VECCALC> uBuf_;
// Communication topology (filled from the HCCL context in Init()).
int32_t rank;
int32_t rankSize;
int32_t aivNum;
// Tiling fields copied from tilingData.cocTiling in Init().
int32_t m0;
int32_t k0;
int32_t n0;
int32_t swizzlOffset;
int32_t swizzlDirect;
int32_t ubMoveNum;
int32_t pValue;
int32_t commNpuSplit;
int32_t commDataSplit;
int32_t lenPerLoop;
// Problem shape and MOE configuration from dispatchFFNCombineInfo.
int32_t m;
int32_t k;
int32_t n;
int32_t topK;
int32_t expertPerRank;
int32_t maxOutputSize;
int32_t EP;
// Tiling for the embedded moe_init_routing_quant_v2 stage.
optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
uint64_t initRoutingQuantTilingKey;
// Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
};
// Init: caches all kernel arguments, copies the tiling fields into members,
// and resolves the local rank / world size from the HCCL context attached to
// the kernel by the runtime.
// Change vs. original: removed the unused local
//   `auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;`
// — all tiling access goes through `tilingData` created by GET_TILING_DATA.
template <TemplateMMA2AClass>
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
GET_TILING_DATA(tilingData, tilingGM);
// Capture raw GM pointers.
xGM_ = xGM;
weight1GM_ = weight1GM;
weight2GM_ = weight2GM;
expertIdGM_ = expertIdGM;
scale1GM_ = scale1GM;
scale2GM_ = scale2GM;
probs_ = probs;
outGM_ = outGM;
workspaceGM_ = workspaceGM;
// Problem shape and MOE configuration.
aivNum = tilingData.dispatchFFNCombineInfo.aivNum;
m = tilingData.dispatchFFNCombineInfo.M;
k = tilingData.dispatchFFNCombineInfo.K;
n = tilingData.dispatchFFNCombineInfo.N;
EP = tilingData.dispatchFFNCombineInfo.worldSize;
topK = tilingData.dispatchFFNCombineInfo.topK;
expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank;
maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize;
// Matmul / communication tiling.
m0 = tilingData.cocTiling.m0;
k0 = tilingData.cocTiling.k0;
n0 = tilingData.cocTiling.n0;
swizzlDirect = tilingData.cocTiling.swizzleDirect;
swizzlOffset = tilingData.cocTiling.swizzleOffset;
ubMoveNum = tilingData.cocTiling.ubMoveNum;
pValue = tilingData.cocTiling.pValue;
commNpuSplit = tilingData.cocTiling.commNpuSplit;
commDataSplit = tilingData.cocTiling.commDataSplit;
lenPerLoop = tilingData.cocTiling.lenPerLoop;
moeInitRoutingQuantV2TilingData = tilingData.cocTiling.moeInitRoutingQuantV2TilingData;
initRoutingQuantTilingKey = tilingData.cocTiling.initRoutingQuantTilingKey;
// Rank identity comes from the HCCL op-resource context, not from tiling.
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
rank = WinContext_->localUsrRankId;
rankSize = WinContext_->rankSize;
}
// Process: assembles the Catlass block-level types (mmad + two epilogues)
// and invokes the fused dispatch->GMM1->swiglu/quant->GMM2->combine kernel.
// Change vs. original: removed eight locals that were never read
// (activeNum, expertCapacity, expertNum, dropPadMode,
// expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, quantMode,
// workspaceStages) — apparent leftovers from an earlier moe_init_routing
// call pattern. No behavior change.
template <TemplateMMA2AClass>
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
{
// Define ArchTag
using ArchTag = Arch::AtlasA2;
constexpr bool enableUnitFlag = false;
constexpr bool enableShuffleK = true;
// GMM2 consumes the gated half of the GMM1 output: K2 = N/2, N2 = K.
uint32_t k2 = n/2;
uint32_t n2 = k;
using LayoutA = layout::RowMajor;
// Weight layout: NZ when Nz_; otherwise column-major when TB_ (transposed
// weights), else row-major.
using LayoutB = typename std::conditional<
Nz_,
layout::zN,
typename std::conditional<TB_, layout::ColumnMajor, layout::RowMajor>::type
>::type;
LayoutB layoutB1 = LayoutBInitializer<LayoutB, BType_>::create(k, n);
LayoutB layoutB2 = LayoutBInitializer<LayoutB, BType_>::create(k2, n2);
using LayoutC = layout::RowMajor;
using L1TileShape = GemmShape<128, 256, 512>; // M, N, K
constexpr uint32_t preloadStages = 1;
constexpr uint32_t l1Stages = 2;
constexpr uint32_t l0AStages = 2;
constexpr uint32_t l0BStages = 2;
constexpr uint32_t l0CStages = 1;
using DispatchPolicy = Gemm::MmadAtlasA2PreloadAsyncFixpipe<
preloadStages,
l1Stages, l0AStages, l0BStages, l0CStages,
enableUnitFlag, enableShuffleK
>;
using L0TileShape = GemmShape<128, 256, 128>;
// Both matmuls run int8 x int8 -> fp16 accum-out; epilogues dequantize.
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, LayoutB>;
using CType = Gemm::GemmType<float16_t, layout::RowMajor>;
using D1Type = Gemm::GemmType<int8_t, layout::RowMajor>;
// Final output element follows CType_ (bf16 stays bf16).
using D2Type = typename std::conditional<
std::is_same_v<CType_, bfloat16_t>,
Gemm::GemmType<bfloat16_t, layout::RowMajor>,
Gemm::GemmType<CType_, layout::RowMajor>
>::type;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
constexpr uint32_t ubStages = 2;
// Epilogue 1: per-token dequant + swiglu + re-quant to int8 for GMM2 input.
using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant<ubStages>;
using ScaleType = Gemm::GemmType<uint64_t, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using ElementMulType = Gemm::GemmType<float, layout::RowMajor>;
using TileElemWiseMuls = Epilogue::Tile::TileElemWiseMuls<ArchTag, ElementMulType, 0>;
using TileCopy1 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D1Type>;
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
D1Type, TileElemWiseMuls, TileCopy1>;
// Epilogue 2: per-token dequant of the GMM2 output.
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
D2Type, TileCopy2>;
using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<9, 1>;
using ElementGroupList = int64_t;
using MatmulKernel = Gemm::Kernel::DispatchFFNCombineKernel<BlockMmad,
BlockScheduler, ElementGroupList, BlockEpilogue1, BlockEpilogue2>;
// Concrete layouts for the two matmul stages and their scales/outputs.
LayoutA layoutA1{static_cast<uint32_t>(m), static_cast<uint32_t>(k)};
LayoutA layoutA2{static_cast<uint32_t>(m), static_cast<uint32_t>(k2)};
layout::VectorLayout layoutScale1{static_cast<uint32_t>(n)};
layout::VectorLayout layoutScale2{static_cast<uint32_t>(n2)};
layout::RowMajor layoutD1{static_cast<uint32_t>(maxOutputSize), static_cast<uint32_t>(k2)};
layout::RowMajor layoutD2{static_cast<uint32_t>(m*topK), static_cast<uint32_t>(n2)};
// Prepare params
GemmCoord problemShape{static_cast<uint32_t>(m), static_cast<uint32_t>(n), static_cast<uint32_t>(k)};
// Half of the vector cores run the epilogue.
uint32_t epilogueCoreNum = aivNum / 2;
uint32_t epilogueGranularity = expertPerRank - 1;
typename MatmulKernel::Params params{
problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
static_cast<uint32_t>(rank), static_cast<uint32_t>(rankSize),
static_cast<uint32_t>(topK), initRoutingQuantTilingKey,
epilogueCoreNum, epilogueGranularity,
xGM_, layoutA1, layoutA2,
weight1GM_, layoutB1,
weight2GM_, layoutB2,
scale1GM_, layoutScale1,
scale2GM_, layoutScale2,
outGM_, layoutD1, layoutD2,
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
expertTokensBeforeCapacity, probs_,
workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
//Call kernel
MatmulKernel kernel(params);
kernel(params);
}
} // DispatchFFNCombineImpl
#endif // DISPATCH_FFN_COMBINE_H

View File

@@ -0,0 +1,814 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATH_FFN_COMBINE_KERNEL_HPP
#define DISPATH_FFN_COMBINE_KERNEL_HPP
#include "kernel_operator.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/epilogue/tile/tile_copy.hpp"
#include "utils/block_mmad_preload_async_fixpipe_quant.hpp"
#include "utils/copy_gm_to_l1_custom.hpp"
#include "utils/copy_l0c_to_gm_custom.hpp"
#include "utils/block_epilogue_pertoken_row.hpp"
#include "utils/block_epilogue_pertoken_swiglu.hpp"
#include "utils/hccl_shmem.hpp"
#include "utils/const_args.hpp"
#include "utils/layout3d.hpp"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h"
#include "unpermute/moe_token_unpermute.h"
using namespace AscendC;
namespace Catlass::Gemm::Kernel {
template <
class BlockMmad_,
class BlockScheduler_,
class ElementGroupList_,
class BlockEpilogue1_,
class BlockEpilogue2_
>
class DispatchFFNCombineKernel {
public:
using BlockMmad = BlockMmad_;
using ArchTag = typename BlockMmad::ArchTag;
using L1TileShape = typename BlockMmad::L1TileShape;
using ElementA = typename BlockMmad::ElementA;
using LayoutA = typename BlockMmad::LayoutA;
using ElementB = typename BlockMmad::ElementB;
using LayoutB = typename BlockMmad::LayoutB;
using ElementC = typename BlockMmad::ElementC;
using LayoutC = typename BlockMmad::LayoutC;
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using ElementScale = uint64_t;
using LayoutScale = typename layout::VectorLayout;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = typename layout::VectorLayout;
using BlockScheduler = BlockScheduler_;
using BlockEpilogue1 = BlockEpilogue1_;
using BlockEpilogue2 = BlockEpilogue2_;
using ElementD1 = typename BlockEpilogue1::ElementD;
using LayoutD1 = typename BlockEpilogue1::LayoutD;
using ElementD2 = typename BlockEpilogue2::ElementD;
using LayoutD2 = typename BlockEpilogue2::LayoutD;
/// Parameters structure
struct Params {
// Data members
GemmCoord problemShape;             // original (M, N, K) before routing
__gm__ ElementA *ptrA;              // activation input
LayoutA layoutA;                    // GMM1 input layout (M x K)
LayoutA layoutA2;                   // GMM2 input layout (M x K2)
__gm__ ElementB *ptrB1;             // FFN up/gate weights
LayoutB layoutB1;
__gm__ ElementB *ptrB2;             // FFN down weights
LayoutB layoutB2;
__gm__ ElementScale *ptrScale1;     // per-channel dequant scales, GMM1
LayoutScale layoutScale1;
__gm__ ElementScale *ptrScale2;     // per-channel dequant scales, GMM2
LayoutScale layoutScale2;
__gm__ ElementD2 *ptrOutput;        // combined kernel output
LayoutD1 layoutD1;
LayoutD2 layoutD2;
GM_ADDR ptrWorkspace;
int32_t EP;                         // expert-parallel world size
int32_t expertPerRank;
uint32_t maxOutputSize;             // cap on permuted rows per rank
uint32_t rank;
uint32_t rankSize;
int32_t ubMoveNum;                  // elements per UB move in GM-to-GM copies
//--------------
GM_ADDR expertIdx;                  // top-k expert ids per token
GM_ADDR moeInitRoutingQuantV2Scale;
GM_ADDR moeInitRoutingQuantV2Offset;
// NOTE(review): expandedX, expandedRowIdx, expertTokensCountOrCumsum and
// dynamicQuantScale are never set by the constructor below; confirm they
// are intentionally unused on this path before relying on them.
GM_ADDR expandedX;
GM_ADDR expandedRowIdx;
GM_ADDR expertTokensCountOrCumsum;
GM_ADDR expertTokensBeforeCapacity;
GM_ADDR dynamicQuantScale;
GM_ADDR probs;                      // top-k routing weights for combine
int64_t topK;
uint64_t initRoutingQuantTilingKey;
uint32_t epilogueCoreNum;
uint32_t epilogueGranularity;
optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
//--------------
// Methods
CATLASS_HOST_DEVICE
Params() {}
// Main constructor. Changes vs. original:
// - `layoutD1_` is now declared with its correct type LayoutD1 (it was
//   LayoutD2, which only compiled because both are RowMajor here);
// - member initializers are listed in declaration order so the textual
//   order matches the actual initialization order (silences -Wreorder).
// No initializer reads another member, so reordering is behavior-neutral.
CATLASS_HOST_DEVICE
Params(
GemmCoord problemShape_,
uint32_t EP_, uint32_t expertPerRank_, uint32_t maxOutputSize_,
uint32_t rank_, uint32_t rankSize_, int64_t topK_,
uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_,
GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_,
GM_ADDR ptrB1_, LayoutB layoutB1_,
GM_ADDR ptrB2_, LayoutB layoutB2_,
GM_ADDR ptrScale1_, LayoutScale layoutScale1_,
GM_ADDR ptrScale2_, LayoutScale layoutScale2_,
GM_ADDR ptrOutput_, LayoutD1 layoutD1_, LayoutD2 layoutD2_,
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
GM_ADDR moeInitRoutingQuantV2Offset_,
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
) : problemShape(problemShape_),
ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_),
ptrB1(reinterpret_cast<__gm__ ElementB *>(ptrB1_)), layoutB1(layoutB1_),
ptrB2(reinterpret_cast<__gm__ ElementB *>(ptrB2_)), layoutB2(layoutB2_),
ptrScale1(reinterpret_cast<__gm__ ElementScale *>(ptrScale1_)), layoutScale1(layoutScale1_),
ptrScale2(reinterpret_cast<__gm__ ElementScale *>(ptrScale2_)), layoutScale2(layoutScale2_),
ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_),
ptrWorkspace(ptrWorkspace_),
EP(EP_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
rank(rank_), rankSize(rankSize_), ubMoveNum(ubMoveNum_),
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
topK(topK_), initRoutingQuantTilingKey(initRoutingQuantTilingKey_),
epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_),
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
{
}
};
// Methods
// Constructor: derives this core's flat index and the total core count, then
// binds the GM tensor views. Cube (AIC) and vector (AIV) cores enumerate
// themselves differently — AIV folds the sub-block id into a flat index over
// all vector cores.
CATLASS_DEVICE
DispatchFFNCombineKernel(Params const &params)
{
if ASCEND_IS_AIC {
coreIdx = AscendC::GetBlockIdx();
coreNum = AscendC::GetBlockNum();
}
if ASCEND_IS_AIV {
coreIdx = get_block_idx() + get_subblockid() * get_block_num();
coreNum = get_block_num() * get_subblockdim();
}
initBuffer(params);
}
// Trivial destructor: all tensors held here are views into GM/UB that are
// not owned by this object.
CATLASS_DEVICE
~DispatchFFNCombineKernel()
{
}
// Core-type dispatch: the kernel body is specialized per physical core type.
// Cube cores run the two grouped matmuls; vector cores run token dispatch
// and combine. Cross-core flag 2 orders dispatch completion before GMM2.
template <int32_t CORE_TYPE = g_coreType>
CATLASS_DEVICE
void operator()(Params const &params);
// AIC (cube) path: GMM1, then wait on flag 2 (set by the AIV path after
// dispatch) before running GMM2.
template <>
CATLASS_DEVICE
void operator()<AscendC::AIC>(Params const &params)
{
GMM1(params);
AscendC::CrossCoreWaitFlag<0x2>(2);
GMM2(params);
}
// AIV (vector) path: dispatch tokens, barrier across vector cores, signal
// the cube cores via flag 2, then combine the expert outputs.
template <>
CATLASS_DEVICE
void operator()<AscendC::AIV>(Params const &params)
{
Dispatch(params);
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
Combine(params);
}
private:
// Builds the workspace / peer-memory descriptors and binds every GM tensor
// view used by the matmul and epilogue stages.
CATLASS_DEVICE void initBuffer(Params const &params) {
workspaceInfo = WorkspaceInfo(params);
peermemInfo = PeermemInfo(params, shmem);
cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM));
gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA));
gmS.SetGlobalBuffer(params.ptrScale1);
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC));
gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken));
gmS2.SetGlobalBuffer(params.ptrScale2);
gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2));
gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale));
gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2));
// Per-(src rank, dst rank, expert) token counts live in the shared window.
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
// NOTE(review): the "+ 8" appears to reserve extra slots per rank segment
// (a sync counter is written at index EP*expertPerRank elsewhere) —
// confirm against Layout3D and the all-gather below.
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank + 8, params.expertPerRank);
}
// Double-buffered GM -> UB -> GM copy of `elemNum` elements, moving
// `ubMoveNum` elements per iteration through two 32 KB UB staging buffers
// (ping-pong on EVENT_ID0/EVENT_ID1 to overlap MTE2 loads with MTE3 stores).
template<typename T>
CATLASS_DEVICE void CopyGMToGM(
AscendC::GlobalTensor<T> dst,
AscendC::GlobalTensor<T> src,
int32_t elemNum,
int32_t ubMoveNum
)
{
// Prime both pipeline slots: the first WaitFlag of each buffer must pair
// with a SetFlag, so both events are set up-front.
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
using TType = Gemm::GemmType<T, layout::RowMajor>;
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
constexpr int32_t BufferNum = 2;
int tmpBufferSize = 32 * 1024 / sizeof(T); // 32 KB
AscendC::LocalTensor<T> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<T>(0);
tmpBuffer1.SetSize(tmpBufferSize);
int tmpBufferOffset = 96 * 1024; // half of UB
AscendC::LocalTensor<T> tmpBuffer2 = resource.ubBuf.template GetBufferByByte<T>(tmpBufferOffset);
tmpBuffer2.SetSize(tmpBufferSize);
int pingpongId = 0;
auto processCount = CeilDiv(elemNum, ubMoveNum);
for (uint32_t processIndex = 0; processIndex < processCount; ++processIndex) {
// Last chunk may be shorter than ubMoveNum.
uint32_t curProcessNum = (processIndex == processCount - 1) ? elemNum - ubMoveNum * (processCount - 1) : ubMoveNum;
AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1;
AscendC::LocalTensor<T> buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2;
auto processOffset = processIndex * ubMoveNum;
auto inputOffset = processOffset;
auto outputOffset = processOffset;
// Wait until the previous store out of this buffer has finished.
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
// Load GM -> UB.
copyGmToUb(buf, src[inputOffset], layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum});
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID);
// Store UB -> GM, then release the buffer for the next load.
copyUbToGm(dst[outputOffset], buf, layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID);
pingpongId = (pingpongId + 1) % BufferNum;
}
// Drain both pipeline slots before returning.
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
}
// Computes, in UB, the running sum over source ranks of the per-expert token
// counts destined for `rankId` (row rankId of each rank's count matrix), and
// writes the EP x expertPerRank cumulative table to `result` (consumed as
// cumsumMM by GMM1/GMM2). Strided DataCopyPad gathers one expertPerRank-wide
// row per source rank; rows land 8-int32-aligned in UB.
CATLASS_DEVICE
void GetCumsumForMMAIV(AscendC::GlobalTensor<int32_t> & tokenPerExpert, AscendC::GlobalTensor<int32_t> & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP)
{
int32_t expertPerRankAligned = (expertPerRank + 8 - 1) / 8 * 8;
AscendC::LocalTensor<int32_t> tmpBuffer1 = resource.ubBuf.template GetBufferByByte<int32_t>(0);
// NOTE(review): tmpResult is acquired but never used below — confirm it is
// dead and can be removed.
AscendC::LocalTensor<int32_t> tmpResult = resource.ubBuf.template GetBufferByByte<int32_t>(EP * expertPerRank * sizeof(int32_t));
// Shorthand for DataCopyPad's uint16_t burst parameters. NOTE(review): the
// macro is not #undef'd, so it leaks into the rest of the translation unit.
#define U16(x) static_cast<uint16_t>(x)
AscendC::DataCopyPad(
tmpBuffer1,
tokenPerExpert[rankId * expertPerRank],
{U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank + 8) * sizeof(int32_t)), 0},
{}
);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
// Prefix-sum across ranks: row i += row i-1, element-wise per expert.
for (uint32_t i = 1; i < EP; ++i) {
AscendC::Add(tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[(i - 1) * expertPerRankAligned], expertPerRank);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
// Write the cumulative table back, dropping the per-row alignment padding.
AscendC::DataCopyPad(
result,
tmpBuffer1,
{U16(EP), U16((expertPerRank) * sizeof(int32_t)), 0, 0}
);
}
// GMM1: grouped matmul #1 (permuted tokens x w1) over this rank's experts,
// run on the cube cores. The per-expert row count comes from cumsumMM (the
// cumulative token counts computed by the vector cores), clamped so the
// running total never exceeds maxOutputSize. Cross-core flag 0 paces each
// expert group behind the vector-core dispatch of that group.
CATLASS_DEVICE
void GMM1(Params const &params){
icache_preload(8);
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
int64_t gmGroupOffsetC = 0;
uint32_t startCoreIdx = 0;
uint32_t syncGroupIdx = 0;
AscendC::CrossCoreWaitFlag<0x2>(0); // wait for the AIVs to publish cumsumMM
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// Total rows routed to this expert (last rank's cumulative entry),
// clamped against maxOutputSize.
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preCurrentmSum >= params.maxOutputSize) {
currentM = 0;
} else if (preCurrentmSum + currentM >= params.maxOutputSize) {
currentM = params.maxOutputSize - preCurrentmSum;
}
AscendC::GlobalTensor<ElementB> gmB1;
gmB1.SetGlobalBuffer(params.ptrB1);
if (currentM <= L1TileShape::M) {
// Small group: weights won't be reused, skip L2 caching.
gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
LayoutB layoutB1 = params.layoutB1;
LayoutScale layoutScale = params.layoutScale1;
LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n());
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
// Determine the starting loopIdx of the current core under the current groupIdx
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
// Loop through the matmul of each groupIdx
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
// Consume one dispatch-done flag per expert group up to groupIdx.
for(;syncGroupIdx <= groupIdx; syncGroupIdx++) {
AscendC::CrossCoreWaitFlag<0x2>(0);
}
// Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
// Compute initial location in logical coordinates
MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N};
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
int64_t gmOffsetB = layoutB1.GetOffset(offsetB);
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // one scale group per expert
if (currentM > 0) {
blockMmad(
gmA[gmGroupOffsetA + gmOffsetA], layoutA,
gmB1[gmGroupOffsetB + gmOffsetB], layoutB1,
gmC[gmGroupOffsetC + gmOffsetC], layoutC,
gmS[gmOffsetS], layoutScale,
actualBlockShape
);
}
}
// At the epilogue-granularity boundary, flush the async pipeline and
// signal the epilogue that this prefix of groups is done.
if ((groupIdx + 1) == params.epilogueGranularity && (groupIdx < params.expertPerRank - 1)) {
syncLoopIdx ++;
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
blockMmad.Finalize(syncLoopIdx, 1);
}
preCurrentmSum += currentM;
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
blockMmad.Finalize(syncLoopIdx + 1, 1);
}
// GMM2: grouped matmul #2 (swiglu-quantized GMM1 output x w2) back to the
// model dimension (N2 = K, K2 = N/2). Mirrors GMM1's per-group scheduling;
// each group's last wave records syncLoopIdx so the epilogue (flag id 3)
// can dequantize group-by-group.
CATLASS_DEVICE
void GMM2(Params const &params) {
icache_preload(8);
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
uint32_t n2 = params.problemShape.k();
uint32_t k2 = params.problemShape.n() / 2;
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
int64_t gmGroupOffsetC = 0;
uint32_t startCoreIdx = 0;
AscendC::PipeBarrier<PIPE_ALL>();
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
// Number of trailing groups whose dequant is deferred past the
// epilogue-granularity boundary.
uint32_t lastDequantExpertNum = params.expertPerRank;
if (params.epilogueGranularity < params.expertPerRank) {
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// Rows for this expert (cumulative count of the last rank), clamped.
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (preCurrentmSum >= params.maxOutputSize) {
currentM = 0;
} else if (preCurrentmSum + currentM > params.maxOutputSize) {
currentM = params.maxOutputSize - preCurrentmSum;
}
AscendC::GlobalTensor<ElementB> gmB2;
gmB2.SetGlobalBuffer(params.ptrB2);
if (currentM <= L1TileShape::M) {
// Small group: weights won't be reused, skip L2 caching.
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
}
GemmCoord inGroupProblemShape{currentM, n2, k2}; // M N K
LayoutA layoutA = params.layoutA2.GetTileLayout(inGroupProblemShape.GetCoordMK());
LayoutB layoutB2 = params.layoutB2;
LayoutScale layoutScale = params.layoutScale2;
LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n());
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
// Determine the starting loopIdx of the current core under the current groupIdx
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
// Loop through the matmul of each groupIdx
if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) {
AscendC::CrossCoreWaitFlag<0x2>(2);
}
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
// Mark the final wave of this group for the epilogue handshake.
if (loopIdx + coreNum >= coreLoops) {
syncLoopIdx = groupIdx;
}
// Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
// Compute initial location in logical coordinates
MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N};
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
int64_t gmOffsetB = layoutB2.GetOffset(offsetB);
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // one scale group per expert
if (currentM > 0) {
blockMmad(
gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
gmB2[gmGroupOffsetB + gmOffsetB], layoutB2,
gmC2[gmGroupOffsetC + gmOffsetC], layoutC,
gmS2[gmOffsetS], layoutScale,
actualBlockShape, syncLoopIdx, 3
);
}
}
preCurrentmSum += currentM;
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
blockMmad.Finalize(params.expertPerRank - 1, 3);
}
// All-gathers this rank's token-per-expert count matrix into every peer's
// shared window and barriers across ranks. `count` is a monotonically
// increasing barrier epoch (read from and stored back to sync_base); each of
// the first EP cores pushes the local counts plus the epoch to one peer,
// then spin-waits until that peer's epoch value shows up locally.
CATLASS_DEVICE
void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){
uint64_t flag_offset = (shmem.SegmentSize() - MB_SIZE) / sizeof(int32_t);
__gm__ int32_t* sync_base = shmem.SyncBaseAddr();
int count = gm_load(sync_base) + 1;
if (coreIdx < params.EP && coreIdx != params.rank) {
AscendC::GlobalTensor<int32_t> srcAddress;
srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
AscendC::GlobalTensor<int32_t> dstAddress;
// Destination: the same slot inside peer rank coreIdx's window.
__gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx);
dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
using TType = Gemm::GemmType<int32_t, layout::RowMajor>;
using CopyGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, TType>;
using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
CopyGmToUb copyGmToUb;
CopyUbToGm copyUbToGm;
AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
uint32_t tmp = params.EP * params.expertPerRank;
copyGmToUb(tmpBuffer, srcAddress[0],
layout::RowMajor{ 1, tmp},
layout::RowMajor{1, tmp});
// Append the barrier epoch right after the EP*expertPerRank counts so
// the peer can detect arrival with a single extra slot.
tmpBuffer.SetValue(params.EP * params.expertPerRank, count);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
copyUbToGm(dstAddress[0], tmpBuffer,
layout::RowMajor{ 1, tmp + 1},
layout::RowMajor{1, tmp + 1});
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
// Spin until peer coreIdx's epoch lands in our window.
__gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(coreIdx, params.EP, 0);
gm_signal_wait_until_eq_for_barrier(sync_check, count);
}
AscendC::SyncAll<true>();
// Persist the new epoch for the next barrier round.
gm_store(sync_base, count);
}
// Dispatch phase: quantize + permute local tokens, all-gather the routing
// matrix, then pull each expert group's rows from every peer rank and run
// the first-matmul epilogue (dequant) over the gathered slices.
// Fixes vs. original: removed a stray double semicolon and the dead locals
// curGroupOffset / groupIdxDeq (assigned but never read).
CATLASS_DEVICE
void Dispatch(Params const &params) {
    icache_preload(8);
    // Offset of this rank's token-per-expert row inside the peer-memory
    // communication matrix (one int32 counter per (rank, expert) pair).
    int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
    GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // the whole communication matrix lives in peermem
    uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);
    // --- initRouting: expert-sort and quantize the local tokens into peermem ---
    moe_init_routing_quant_v2<ElementD2>(reinterpret_cast<GM_ADDR> (params.ptrA), params.expertIdx,
        params.moeInitRoutingQuantV2Scale, params.moeInitRoutingQuantV2Offset, shmem() + peermemInfo.offsetA,
        workspaceInfo.expandedRowIdx, localTokenPerExpert, params.expertTokensBeforeCapacity,
        shmem() + peermemInfo.offsetPeerPerTokenScale,
        params.ptrWorkspace + expandedRowIdxOffset,
        &params.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey);
    AscendC::SyncAll<true>();
    CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset);
    if (coreIdx == 0) {
        GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
    }
    AscendC::SyncAll<true>();
    AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0);
    // Rows in the peer rank's buffer that precede this rank's experts.
    int32_t prevSumBeforeRank = 0;
    if (coreIdx < params.EP) {
        for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) {
            prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i));
        }
        m_prevSumBeforeRank = prevSumBeforeRank; // cached for Combine()
    }
    int prevSum = prevSumBeforeRank;
    uint32_t prevGroupSum1 = 0;
    uint32_t dequantSum = 0;
    int32_t syncLoopIdx = -1;
    BlockEpilogue1 blockEpilogue(resource);
    for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Core i reads this expert group's rows from rank i's peermem
        // (strided by coreNum so all EP peers are covered).
        for (int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
            uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
            if (rowStart < params.maxOutputSize) {
                uint32_t rows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
                if (rowStart + rows > params.maxOutputSize) {
                    rows = params.maxOutputSize - rowStart; // clamp to output capacity
                }
                uint32_t rowSrc = prevSum;
                prevSum += rows;
                GM_ADDR otherRankPtr = shmem(0, dstEpIdx);
                AscendC::GlobalTensor<ElementA> gmRemoteA;
                gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA));
                AscendC::GlobalTensor<ElementPerTokenScale> gmRemotePerTokenScale;
                gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale));
                MatrixCoord offsetA{rowStart, 0};
                MatrixCoord shapeA{rows, params.problemShape.k()};
                MatrixCoord offsetPeer{rowSrc, 0};
                int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
                int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
                // Transfer activation data.
                CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
                // Transfer the per-token quantization scales.
                CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
            }
        }
        if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
            syncLoopIdx++;
            AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
        }
        AscendC::SyncAll<true>();
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // AIV notifies AIC that this round's communication is complete
        if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
            // Early epilogue pass over the rows accumulated so far.
            uint32_t rowStartThisCore = 0;
            MatrixCoord offsetC{0U, 0};
            uint32_t dequantLen = prevGroupSum1 - dequantSum;
            if (dequantLen >= params.maxOutputSize) {
                dequantLen = dequantLen - params.maxOutputSize;
            }
            MatrixCoord shapeC{dequantLen, params.problemShape.n()};
            LayoutC layoutC{dequantLen, params.problemShape.n()};
            int64_t gmOffsetC = layoutC.GetOffset(offsetC);
            int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
            blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
        }
        prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
        if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
            dequantSum = 0;
        }
    }
    syncLoopIdx++;
    AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
    AscendC::SyncAll<true>();
    uint32_t lastDequantExpertNum = params.expertPerRank;
    if (params.epilogueGranularity < params.expertPerRank) {
        lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
    }
    if (lastDequantExpertNum < params.expertPerRank) {
        AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2);
    }
    // Final epilogue pass over the remaining (not yet dequantized) rows.
    if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
        uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;
        MatrixCoord offsetC{rowStartThisCore, 0};
        uint32_t dequantLen = dequantSum;
        if (prevGroupSum1 >= params.maxOutputSize) {
            dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
        }
        MatrixCoord shapeC{dequantLen, params.problemShape.n()};
        LayoutC layoutC{dequantLen, params.problemShape.n()};
        int64_t gmOffsetC = layoutC.GetOffset(offsetC);
        int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
        blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
    }
    blockEpilogue.Finalize();
}
// Combine phase: for each expert group, stream this rank's FFN results back
// into the peer memory of the rank that owns the tokens, then un-permute the
// gathered rows (weighted by params.probs) into the final output tensor.
CATLASS_DEVICE
void Combine(Params const &params) {
    int32_t prevSumBeforeRank = 0;
    if (coreIdx < params.EP) {
        prevSumBeforeRank = m_prevSumBeforeRank; // row prefix cached by Dispatch()
    }
    int prevSum = prevSumBeforeRank;
    uint32_t n2 = params.problemShape.k();     // output width of the second matmul
    uint32_t k2 = params.problemShape.n() / 2; // NOTE(review): computed but unused here
    // TODO: compute the cumsum of tokenPerExpert
    typename BlockEpilogue2::Params epilogueParams{
        static_cast<int32_t>(params.EP),
        static_cast<int32_t>(params.expertPerRank),
        reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace)
    };
    BlockEpilogue2 blockEpilogue(resource, epilogueParams);
    int32_t prevGroupSum2 = 0;
    for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
        // Wait on the per-group cross-core flag (flag ids start at 3).
        AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3);
        AscendC::SyncAll<true>();
        // Core i writes back to rank i's peermem (strided by coreNum).
        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
            __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx);
            AscendC::GlobalTensor<ElementD2> gmRemotePeer;
            gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr));
            uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2;
            if (srcRowOffset < params.maxOutputSize) {
                uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
                if (srcRowOffset + dataRows > params.maxOutputSize) {
                    dataRows = params.maxOutputSize - srcRowOffset; // clamp to capacity
                }
                uint32_t dstRowOffset = prevSum;
                prevSum += dataRows;
                MatrixCoord offsetC{srcRowOffset, 0};
                MatrixCoord offsetPeer{dstRowOffset, 0};
                MatrixCoord shapeC{dataRows, n2};
                int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC);
                int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer);
                // int8 input path additionally applies the per-token dequant scale.
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]);
                } else {
                    blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]);
                }
            }
        }
        prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
    }
    blockEpilogue.Finalize();
    AscendC::SyncAll<true>();
    // All ranks must finish writing peermem before unpermute reads it.
    shmem.CrossRankSync();
    MoeTokenUnpermuteTilingData tilingData;
    MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum);
    KernelMoeTokenUnpermute<ElementD2, int32_t, float, true> kernelMoeTokenUnpermuteOp;
    kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast<GM_ADDR>(params.ptrOutput), &tilingData);
    kernelMoeTokenUnpermuteOp.Process();
}
private:
// Carves the caller-provided workspace into the scratch buffers used by the
// kernel; each field is a pointer into params.ptrWorkspace.
struct WorkspaceInfo {
    GM_ADDR ptrA;              // gathered activations; aliases ptrPermutedToken
    GM_ADDR ptrPerTokenScale;
    GM_ADDR ptrcumsumMM;       // cumulative token counts for the grouped matmul
    GM_ADDR ptrC;              // matmul output; aliases ptrC2 (both matmuls share it)
    GM_ADDR ptrC2;
    GM_ADDR ptrPermutedToken;
    GM_ADDR ptrPerTokenScale2;
    GM_ADDR expandedRowIdx;    // token permutation produced by init-routing
    GM_ADDR ptrTokenPerExpert;
    CATLASS_DEVICE
    WorkspaceInfo(){}
    CATLASS_DEVICE
    WorkspaceInfo(const Params & params) {
        uint32_t k2 = params.problemShape.n() / 2; // second-matmul reduction dim
        uint32_t n2 = params.problemShape.k();     // second-matmul output dim
        int64_t workspaceOffset = 0;
        expandedRowIdx = params.ptrWorkspace;
        workspaceOffset += AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);
        ptrcumsumMM = params.ptrWorkspace + workspaceOffset;
        // NOTE(review): the offset advances by two EP*EP*expertPerRank int32
        // matrices here but only one pointer (ptrcumsumMM) is taken —
        // presumably the second matrix is extra scratch; confirm.
        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
        ptrPerTokenScale = params.ptrWorkspace + workspaceOffset;
        workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale);
        ptrPerTokenScale2 = params.ptrWorkspace + workspaceOffset;
        workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale);
        ptrTokenPerExpert = params.ptrWorkspace + workspaceOffset;
        workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t);
        // The C buffer is shared by both matmuls: size it for the larger one.
        ptrC = params.ptrWorkspace + workspaceOffset;
        ptrC2 = ptrC;
        workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC),
            params.maxOutputSize * n2 * sizeof(ElementC));
        // The A buffer likewise doubles as the permuted-token buffer.
        ptrA = params.ptrWorkspace + workspaceOffset;
        ptrPermutedToken = ptrA;
        workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA),
            params.maxOutputSize * k2 * sizeof(ElementA));
    }
};
// Byte offsets of the logical regions carved out of each rank's shared
// peer-memory segment.
struct PeermemInfo {
    int64_t offsetA;                  // quantized activations (~1/3 of the segment)
    int64_t offsetPeerPerTokenScale;  // per-token scales (1 MB)
    int64_t offsetPeerTokenPerExpert; // token-per-expert comm matrix (final 2 MB)
    int64_t offsetD;                  // FFN outputs (remaining space)
    CATLASS_DEVICE
    PeermemInfo(){}
    CATLASS_DEVICE
    PeermemInfo(const Params & params, const HcclShmem & shmem) {
        const int64_t segBytes = shmem.SegmentSize();
        offsetA = 0;
        offsetPeerPerTokenScale = AlignUp(segBytes / 3, 512);
        offsetD = offsetPeerPerTokenScale + MB_SIZE;
        offsetPeerTokenPerExpert = segBytes - 2 * MB_SIZE;
    }
};
Arch::Resource<ArchTag> resource;   // per-core on-chip buffer resources
uint32_t coreIdx;                   // index of this vector core
uint32_t coreNum;                   // number of participating vector cores
Params params;
WorkspaceInfo workspaceInfo;        // scratch-workspace layout
PeermemInfo peermemInfo;            // shared peer-memory layout
int64_t m_prevSumBeforeRank;        // row prefix before this rank; written by Dispatch, read by Combine
AscendC::GlobalTensor<ElementA> gmA;                 // gathered activations (workspace)
AscendC::GlobalTensor<ElementC> gmC;                 // first matmul output
AscendC::GlobalTensor<ElementScale> gmS;             // scales — presumably weight scales; confirm against Params
AscendC::GlobalTensor<ElementD1> gmPermutedToken;    // epilogue-1 output (permuted tokens)
AscendC::GlobalTensor<ElementScale> gmS2;            // scales for the second matmul — confirm
AscendC::GlobalTensor<ElementC> gmC2;                // second matmul output
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale1;  // per-token scales for epilogue 1
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale2;  // per-token scales produced for matmul 2
AscendC::GlobalTensor<int32_t> tokenPerExpert;       // EP-wide token-per-expert communication matrix
AscendC::GlobalTensor<int32_t> cumsumMM;             // per-group cumulative row counts (see GetCumsumForMMAIV)
Layout3D tokenPerExpertLayout;                       // (rank, rank, expert) indexing into tokenPerExpert
HcclShmem shmem;                                     // cross-rank shared-memory accessor
};
} // namespace Catlass::Gemm::Kernel
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP

View File

@@ -0,0 +1,56 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dispatch_ffn_combine_tiling.h
* \brief
*/
// Fix: the include guard must open before the #include directives; otherwise
// repeated inclusion still reprocesses the nested headers.
#ifndef ASCENDC_DISPATCH_FFN_COMBINE_TILING_H
#define ASCENDC_DISPATCH_FFN_COMBINE_TILING_H
#include "moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h"
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
// Static problem description passed from host tiling to the
// dispatch_ffn_combine kernel.
struct DispatchFFNCombineInfo {
    uint32_t M;              // token count (GEMM M)
    uint32_t K;              // hidden size (GEMM K)
    uint32_t N;              // FFN intermediate size (GEMM N)
    uint32_t expertPerRank;  // local experts hosted on each EP rank
    uint32_t maxOutputSize;  // row capacity of the gathered-token buffers
    uint32_t isTransposeB;   // weight transposed flag — per name; confirm
    uint32_t isWeightNz;     // weight in NZ layout flag — per name; confirm
    uint32_t aivNum;         // available vector cores
    uint32_t totalUbSize;    // UB bytes per core
    uint32_t topK;           // experts selected per token
    uint32_t worldSize;      // EP world size
};
// Communication/computation overlap tiling knobs. -1 means "not set"; the
// host tiling fills in concrete values.
struct CoCTiling {
    int32_t m0 = -1;            // matmul base block sizes
    int32_t k0 = -1;
    int32_t n0 = -1;
    int32_t swizzleDirect = -1; // block-swizzle direction
    int32_t swizzleOffset = -1;
    int32_t ubMoveNum = -1;     // elements moved per UB copy
    int32_t pValue = -1;
    int32_t commNpuSplit = -1;
    int32_t commDataSplit = -1;
    int32_t lenPerLoop = -1;
    // Fix: give the tiling key a deterministic default like every other
    // member, so an unset key cannot be read as an indeterminate value.
    uint64_t initRoutingQuantTilingKey = 0;
    optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
};
// Top-level tiling blob for dispatch_ffn_combine: MC2 communication tilings,
// the static problem description, and the comm/compute tiling knobs.
struct DispatchFFNCombineTilingData {
    Mc2InitTiling mc2InitTiling;
    Mc2CcTiling mc2CcTiling;
    DispatchFFNCombineInfo dispatchFFNCombineInfo;
    CoCTiling cocTiling;
};
#endif

View File

@@ -0,0 +1,134 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_init_routing_quant_v2.cpp
* \brief
*/
#include "moe_v2_sort_one_core.h"
#include "moe_v2_sort_multi_core.h"
#include "moe_v2_mrgsort_out.h"
#include "moe_v2_mrgsort.h"
#include "moe_v2_expert_token_out.h"
#include "moe_v2_src_to_dst_op.h"
#include "moe_v2_src_to_dst_with_capacity.h"
#include "moe_v2_fullload_quant.h"
#include "moe_v2_fullload_dynamic_quant.h"
#include "moe_v2_gather_quant.h"
#include "moe_v2_gather_dynamic_quant.h"
#include "moe_v2_src_to_dst_and_gather.h"
using namespace AscendC;
using namespace MoeInitRoutingQuantV2;
using namespace optiling;
// Device-side driver for the quantized MoeInitRoutingV2 operator: sorts
// tokens by expert, writes routing bookkeeping (expanded row indices,
// per-expert token counts) and emits the permuted, quantized activations.
// `tilingKey` selects the pipeline; per the tiling header its digits encode:
//   2xxxx full-load fast path / 1xxxx looped path, +1000 dynamic quant,
//   +100 drop/pad (capacity) mode, +10 multi-core sort.
template <class DTYPE_X = bfloat16_t>
__aicore__ inline void moe_init_routing_quant_v2(
    GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
    GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, GM_ADDR dynamicQuantScale, GM_ADDR workspace,
    const MoeInitRoutingQuantV2TilingData* tilingData, uint64_t tilingKey) {
    // Vector-core-only kernel: cube cores exit immediately.
    if (g_coreType == AIC) {
        return;
    }
    if (workspace == nullptr) {
        return;
    }
    // Full-load fast paths: the whole problem fits in UB in one pass.
    if (tilingKey == 20000) { // quant full load
        TPipe sortPipe;
        MoeV2FullLoadQuant<DTYPE_X> op;
        op.Init(x, expertIdx, scale, offset, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
        return;
    }
    else if (tilingKey == 21000) { // dynamic quant full load
        TPipe sortPipe;
        MoeV2FullLoadDynamicQuant<DTYPE_X> op;
        op.Init(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, scale, dynamicQuantScale, workspace, tilingData,
            &sortPipe);
        op.Process();
        sortPipe.Destroy();
        return;
    }
    // Stage 1 — sort tokens by expert id (single-core vs. multi-core merge sort).
    if (tilingKey == 10000 || tilingKey == 10100 || tilingKey == 11000 || tilingKey == 11100) {
        TPipe sortPipe;
        MoeV2SortOneCore op;
        op.Init<MoeInitRoutingQuantV2TilingData>(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace,
            tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
    } else if (tilingKey == 10010 || tilingKey == 10110 || tilingKey == 11010 || tilingKey== 11110) {
        TPipe sortPipe;
        MoeV2SortMultiCore op;
        op.Init<MoeInitRoutingQuantV2TilingData>(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace,
            tilingData, &sortPipe);
        op.Process();
        sortPipe.Destroy();
    }
    // Stage 2 — routing bookkeeping.
    if (tilingKey == 10000 || tilingKey == 10010 || tilingKey ==11000 || tilingKey ==11010) { // case without drop
        if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) {
            TPipe expertTokenOutPipe;
            MoeV2ExpertTokenOut expertTokenOutOp;
            expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
                expandedRowIdx, workspace, tilingData, &expertTokenOutPipe);
            expertTokenOutOp.Process();
            expertTokenOutPipe.Destroy();
        }
        TPipe srcToDstPipe;
        MoeV2SrcToDstOp srcToDstOp;
        srcToDstOp.Init<MoeInitRoutingQuantV2TilingData>(expandedRowIdx, workspace, tilingData, &srcToDstPipe);
        srcToDstOp.Process();
        srcToDstPipe.Destroy();
    } else if (tilingKey ==10100 || tilingKey ==10110 || tilingKey ==11100 || tilingKey ==11110) { // case with drop (capacity)
        TPipe expertTokenOutPipe;
        MoeV2ExpertTokenOut expertTokenOutOp;
        expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
            expandedRowIdx, workspace, tilingData, &expertTokenOutPipe);
        expertTokenOutOp.Process();
        expertTokenOutPipe.Destroy();
        if (tilingKey == 10100 || tilingKey == 10110) {
            TPipe srcToDstPipe;
            MoeV2SrcToDstWithCapacity<int8_t, MoeInitRoutingQuantV2TilingData> srcToDstWithCapacityOp;
            srcToDstWithCapacityOp.Init(expandedRowIdx, expandedX, workspace, tilingData, &srcToDstPipe);
            srcToDstWithCapacityOp.Process();
            srcToDstPipe.Destroy();
        } else {
            // Dynamic quant + drop: scatter and gather are fused; done here.
            TPipe srcToDstGatherPipe;
            MoeV2SrcToDstAndGather<DTYPE_X, MoeInitRoutingQuantV2TilingData> srcToDstAndGatherOp;
            srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe);
            srcToDstAndGatherOp.Process();
            srcToDstGatherPipe.Destroy();
            return;
        }
    }
    // Stage 3 — gather + quantize the permuted tokens.
    if (tilingKey == 10000 || tilingKey == 10010 || tilingKey == 10100 || tilingKey == 10110) {
        TPipe gatherPipe;
        MoeV2GatherQuant<DTYPE_X> gatherQuantOp;
        gatherQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, workspace, tilingData, &gatherPipe);
        gatherQuantOp.Process();
        gatherPipe.Destroy();
    } else if (tilingKey == 11000 || tilingKey == 11010) {
        TPipe gatherPipe;
        MoeV2GatherDynamicQuant<DTYPE_X> gatherDynamicQuantOp;
        gatherDynamicQuantOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &gatherPipe);
        gatherDynamicQuantOp.Process();
        gatherPipe.Destroy();
    }
}

View File

@@ -0,0 +1,429 @@
#pragma once
#include "moe_init_routing_v2_tiling.h"
namespace optiling {
// Attribute indices, tiling-key components and UB-budget constants for the
// MoeInitRoutingQuantV2 tiling. Tiling keys are composed additively:
// base (10000 looped / 20000 full-load) + quantMode*1000 + dropPadMode*100
// + multiCoreSort*10.
const static int64_t ATTR_QUANT_MODE = 6;
const static int64_t TILING_KEY_BASE = 10000;
const static int64_t TILING_KEY_PERF_BASE = 20000;     // full-load ("perf") path
const static int64_t TILING_KEY_QUANT_BASE = 1000;     // scaled by quantMode
const static int64_t TILING_KEY_DROP_MODE_BASE = 100;  // scaled by dropPadMode
const static int64_t TILING_KEY_SORT_BASE = 10;        // added when multi-core sort is required
const static int64_t FOUR_BLOCK_BYTE = 128;
const static int64_t MAX_COLS_ONE_LOOP_QUANT = 8192;   // max cols per loop, static quant
const static int64_t INDEX_SCALE = 2;                  // presumably input index of scale; confirm
const static int64_t INDEX_OFFSET = 3;                 // presumably input index of offset; confirm
const static int64_t SMOOTH_NONE = 0;                  // no smooth scale supplied
const static int64_t SMOOTH_1H = 1;                    // smooth scale with first dim 1
const static int64_t SMOOTH_EH = 2;                    // smooth scale with first dim > 1 (per expert)
const static int64_t MAX_COLS_DYNAMIC_QUANT = 6144;    // max cols per loop, dynamic quant
const static int64_t DYNAMIC_QUANT_SRC_TO_DST_BUFFER = 15;   // UB column-buffer multiplier (scatter path)
const static int64_t DYNAMIC_QUANT_COLS_BUFFER = 21;         // UB column-buffer multiplier (gather path)
const static int64_t DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER = 13; // UB column-buffer multiplier (full-load path)
const static int64_t DYNAMIC_QUANT_SCALE_SIZE_64 = 64;
const static int64_t DYNAMIC_QUANT_SCALE_SIZE_128 = 128;
const static int64_t OUTOUT_DYNAMIC_QUANT_SCALE = 4;   // NOTE(review): "OUTOUT" looks like a typo for "OUTPUT"
const static int64_t FULLLOAD_H_LIMIT = 7168;
// Round x up to the next multiple of ONE_BLOCK_BYTE (x is expected >= 0).
inline static int64_t AlignOneBlockByte(int64_t x) {
    const int64_t blocks = (x + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE;
    return blocks * ONE_BLOCK_BYTE;
}
// Round x DOWN to a multiple of ONE_BLOCK_BYTE. NOTE(review): despite the
// "Ceil" in the name, this truncates (floor for x >= 0).
inline static int64_t AlignOneBlockByteCeil(int64_t x) {
    return ONE_BLOCK_BYTE * (x / ONE_BLOCK_BYTE);
}
// Tiling data consumed by the quantized init-routing kernel; mirrors the
// base routing tiling plus the smooth-scale type for the quant paths.
struct MoeInitRoutingQuantV2TilingData {
    int64_t coreNum;
    int64_t n;            // row count of x (tokens)
    int64_t cols;         // hidden size per token
    int64_t k;            // top-k experts per token
    int64_t expertCapacity;
    int64_t expertNum;
    int64_t dropPadMode;  // 0: no drop, 1: drop/pad to expertCapacity
    int64_t expertTokensCountOrCumsumFlag;
    int64_t expertTokensBeforeCapacityFlag;
    int64_t smoothType;   // SMOOTH_NONE / SMOOTH_1H / SMOOTH_EH
    InnerMoeV2VBSComputeTilingData vbsComputeParamsOp;             // sort stage
    InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp; // merge-sort middle stage
    InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp;
};
// Tiling calculator for MoeInitRoutingQuantV2: extends the non-quant routing
// tiling with quant / dynamic-quant UB accounting and tiling-key selection.
class MoeInitRoutingQuantV2TilingBase : public InnerMoeInitRoutingV2TilingBase {
public:
protected:
    // Captures problem shape/attrs and derives the smooth-scale type.
    bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override;
    // Composes load mode / quant mode / drop mode / sort mode into one key.
    uint64_t GetTilingKey() const override;
    bool GetWorkspaceSize() override;
    bool PostTiling() override;
public:
    //bool CheckOutShape() override;
    // Full-load feasibility checks; `space` is bytes already reserved in UB.
    bool IsFullLoadQuant(int64_t space);
    bool IsFullLoadDynamicQuant(int64_t space);
    bool IsFullLoad() override;
    // Helpers that fill a gather-out tiling block (single-loop / cols / rows).
    void SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows, int64_t lastCoreRows,
        int64_t cols);
    void SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t baseMaxCols, int64_t cols);
    void SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows,
        int64_t lastCoreRows, int64_t basePerLoopMaxRows);
    void Tiling4GatherQuant();
    void Tiling4GatherDynamicQuant();
    void Tiling4SrcToDstCapacityCompute() override;
    void Tiling4GatherOutCompute() override;
    void CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst, InnerMoeV2GatherOutComputeTilingData& src);
    void CopyTilingData();
    int64_t quantMode;  // 0: static quant, otherwise dynamic quant
    MoeInitRoutingQuantV2TilingData quantTilingData;  // tiling result handed to the kernel
};
// True when the static-quant path can hold its whole per-core slice in UB
// after `space` bytes are reserved for the sort stage.
bool MoeInitRoutingQuantV2TilingBase::IsFullLoadQuant(int64_t space) {
    const int64_t rows = moeInitRoutingTilingData.n;
    int64_t perCoreXRows = rows / aivNum;
    // A core may additionally touch one boundary row on each side.
    perCoreXRows += (rows % aivNum <= 1) ? 1 : NUM_TWO;
    const int64_t bytesPerElem = inuptXDtypeSize_ + sizeof(int8_t) + sizeof(float) + sizeof(int16_t);
    const int64_t quantSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols) * bytesPerElem * perCoreXRows;
    const int64_t remainUbAfterSort = aicoreParams_.ubSize - space - quantSpace;
    return remainUbAfterSort > 0;
}
bool MoeInitRoutingQuantV2TilingBase::IsFullLoadDynamicQuant(int64_t space) {
int64_t quantSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols) * DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER;
int64_t scaleOutSpace = 64;
int64_t remainUbAfterSort = aicoreParams_.ubSize - space - scaleOutSpace - quantSpace;
return remainUbAfterSort > 0;
}
// A problem is "full load" when sort data, bookkeeping and quant buffers fit
// in UB simultaneously; drop/pad mode 1 always takes the looped path.
bool MoeInitRoutingQuantV2TilingBase::IsFullLoad() {
    if (totalLength > sortLoopMaxElement || moeInitRoutingTilingData.cols > MAX_COLS_ONE_LOOP_QUANT ||
        this->dropPadMode == 1) {
        return false;
    }
    const int64_t alignedLen = AlignOneBlockByte(this->totalLength);
    const int64_t sortSpace = alignedLen * sizeof(int32_t) * ONE_CORE_SORT_BUFFER;
    const int64_t otherSpace = alignedLen * sizeof(int32_t) * NUM_THREE;
    const int64_t expertSpace = AlignOneBlockByte(this->expertNum * sizeof(int32_t));
    const int64_t reserved = sortSpace + otherSpace + expertSpace;
    return (quantMode == 0) ? IsFullLoadQuant(reserved) : IsFullLoadDynamicQuant(reserved);
}
// Record shape/attribute info via the base class, then derive the
// smooth-scale type for the dynamic-quant path from the scale's first dim
// (absent -> none, 1 -> SMOOTH_1H, >1 -> SMOOTH_EH).
bool MoeInitRoutingQuantV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
    int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
    bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) {
    InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode,
        expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0);
    this->quantMode = quantMode;
    if (quantMode != 0) {
        if (scaleDim0 <= 0) {
            quantTilingData.smoothType = SMOOTH_NONE;
        } else {
            quantTilingData.smoothType = (scaleDim0 == 1) ? SMOOTH_1H : SMOOTH_EH;
        }
    }
    return true;
}
// Compose the tiling key: full-load problems use the perf base; looped
// problems add drop-mode and multi-core-sort components.
uint64_t MoeInitRoutingQuantV2TilingBase::GetTilingKey() const {
    const uint64_t quantPart = quantMode * TILING_KEY_QUANT_BASE;
    if (isFullLoad) {
        return TILING_KEY_PERF_BASE + quantPart;
    }
    const uint64_t sortPart = (totalLength > sortLoopMaxElement) ? TILING_KEY_SORT_BASE : 0;
    return TILING_KEY_BASE + quantPart + dropPadMode * TILING_KEY_DROP_MODE_BASE + sortPart;
}
// Flush the accumulated base-class tiling fields into quantTilingData.
bool MoeInitRoutingQuantV2TilingBase::PostTiling() {
    CopyTilingData();
    return true;
}
// Field-by-field copy of one gather-out tiling block.
void MoeInitRoutingQuantV2TilingBase::CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst,
    InnerMoeV2GatherOutComputeTilingData& src) {
    dst.needCoreNum = src.needCoreNum;
    dst.activateRows = src.activateRows;
    dst.perCoreRows = src.perCoreRows;
    dst.perCorePerLoopRows = src.perCorePerLoopRows;
    dst.perCoreLastLoopRows = src.perCoreLastLoopRows;
    dst.lastCoreRows = src.lastCoreRows;
    dst.lastCorePerLoopRows = src.lastCorePerLoopRows;
    dst.lastCoreLastLoopRows = src.lastCoreLastLoopRows;
    dst.perCoreLoops = src.perCoreLoops;
    dst.lastCoreLoops = src.lastCoreLoops;
    dst.perLoopCols = src.perLoopCols;
    dst.lastLoopCols = src.lastLoopCols;
    dst.colLoops = src.colLoops;
}
// Mirror the base-class tiling results into quantTilingData.
void MoeInitRoutingQuantV2TilingBase::CopyTilingData() {
    auto& base = InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData;
    quantTilingData.coreNum = base.coreNum;
    quantTilingData.n = base.n;
    quantTilingData.cols = base.cols;
    quantTilingData.k = base.k;
    quantTilingData.expertCapacity = base.expertCapacity;
    quantTilingData.expertNum = base.expertNum;
    quantTilingData.dropPadMode = base.dropPadMode;
    quantTilingData.expertTokensCountOrCumsumFlag = base.expertTokensCountOrCumsumFlag;
    quantTilingData.expertTokensBeforeCapacityFlag = base.expertTokensBeforeCapacityFlag;
    // Sort (VBS) stage tiling.
    auto& vbsSrc = base.vbsComputeParamsOp;
    auto& vbsDst = quantTilingData.vbsComputeParamsOp;
    vbsDst.needCoreNum = vbsSrc.needCoreNum;
    vbsDst.perCoreElements = vbsSrc.perCoreElements;
    vbsDst.perCoreLoops = vbsSrc.perCoreLoops;
    vbsDst.perCorePerLoopElements = vbsSrc.perCorePerLoopElements;
    vbsDst.perCoreLastLoopElements = vbsSrc.perCoreLastLoopElements;
    vbsDst.lastCoreElements = vbsSrc.lastCoreElements;
    vbsDst.lastCoreLoops = vbsSrc.lastCoreLoops;
    vbsDst.lastCorePerLoopElements = vbsSrc.lastCorePerLoopElements;
    vbsDst.lastCoreLastLoopElements = vbsSrc.lastCoreLastLoopElements;
    vbsDst.oneLoopMaxElements = vbsSrc.oneLoopMaxElements;
    quantTilingData.vmsMiddleComputeParamsOp.needCoreNum = base.vmsMiddleComputeParamsOp.needCoreNum;
    quantTilingData.sortOutComputeParamsOp.oneLoopMaxElements = base.sortOutComputeParamsOp.oneLoopMaxElements;
    CopyGatherOutTiling(quantTilingData.srcToDstComputeParamsOp, base.srcToDstComputeParamsOp);
    CopyGatherOutTiling(quantTilingData.srcToDstCapacityComputeParamsOp, base.srcToDstCapacityComputeParamsOp);
}
// Extend the base workspace: the dynamic-quant paths need a per-core float
// column buffer whenever the gather/scatter has to loop over columns.
bool MoeInitRoutingQuantV2TilingBase::GetWorkspaceSize() {
    InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize();
    const bool gatherLoopsCols =
        (dropPadMode == 0) && (quantTilingData.gatherOutComputeParamsOp.colLoops > 1);
    const bool capacityLoopsCols =
        (dropPadMode == 1) &&
        (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.colLoops > 1);
    if (quantMode == 1 && (gatherLoopsCols || capacityLoopsCols)) {
        workspaceSize_ += aivNum * InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols * sizeof(float);
    }
    return true;
}
// Single-loop tiling: every core handles all of its rows and columns at once.
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData,
    int64_t perCoreRows, int64_t lastCoreRows, int64_t cols) {
    tilingData->perCoreLoops = 1;
    tilingData->lastCoreLoops = 1;
    tilingData->colLoops = 1;
    tilingData->perCorePerLoopRows = perCoreRows;
    tilingData->perCoreLastLoopRows = perCoreRows;
    tilingData->lastCorePerLoopRows = lastCoreRows;
    tilingData->lastCoreLastLoopRows = lastCoreRows;
    tilingData->perLoopCols = cols;
    tilingData->lastLoopCols = cols;
}
// Column tiling: split `cols` into loops of at most `baseMaxCols` columns.
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData,
    int64_t baseMaxCols, int64_t cols) {
    tilingData->perLoopCols = std::min(baseMaxCols, cols);
    tilingData->lastLoopCols = GetPerOrLastValue(cols, baseMaxCols);
    tilingData->colLoops = (baseMaxCols == 0) ? 0 : ((cols + baseMaxCols - 1) / baseMaxCols);
}
// Row tiling: split each core's rows into loops of at most
// `basePerLoopMaxRows` rows (separately for the regular and last core).
void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData,
    int64_t perCoreRows, int64_t lastCoreRows,
    int64_t basePerLoopMaxRows) {
    tilingData->perCorePerLoopRows = std::min(perCoreRows, basePerLoopMaxRows);
    tilingData->perCoreLastLoopRows = GetPerOrLastValue(perCoreRows, basePerLoopMaxRows);
    tilingData->perCoreLoops = (basePerLoopMaxRows == 0)
        ? 0 : ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
    tilingData->lastCorePerLoopRows = std::min(lastCoreRows, basePerLoopMaxRows);
    tilingData->lastCoreLastLoopRows = GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows);
    tilingData->lastCoreLoops = (basePerLoopMaxRows == 0)
        ? 0 : ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
}
// Tiling for the src-to-dst scatter in capacity (drop/pad) mode with dynamic
// quant; every other combination delegates to the base implementation.
void MoeInitRoutingQuantV2TilingBase::Tiling4SrcToDstCapacityCompute() {
    if (quantMode == 0 || dropPadMode == 0) {
        InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute();
        return;
    }
    auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        tilingData->needCoreNum = 0;  // nothing to scatter
        return;
    }
    tilingData->needCoreNum = CeilDiv(totalLength, perCoreRows);
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;
    // UB budget: row indices (x4), int8 column buffers (x15) and scale output.
    int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR;
    int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
    int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64;
    if (rowSize + colSize + scaleSize < static_cast<int64_t>(aicoreParams_.ubSize)) {
        // Whole per-core slice fits: single loop over rows and columns.
        SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols);
    } else {
        int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT;
        int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows =
            AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        if (cols < MAX_COLS_DYNAMIC_QUANT) {
            // Columns fit in one loop: recompute the row budget against the real colSize.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Rows fit in one loop: enlarge the per-loop column budget instead.
            baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_SRC_TO_DST_BUFFER;
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
// Tiling for the gather + static-quant output stage: split rows across
// cores, then decide whether each core's slice fits UB in one shot or has to
// loop over rows and/or columns.
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherQuant() {
    auto tilingData = &quantTilingData.gatherOutComputeParamsOp;
    tilingData->activateRows = totalLength;
    if (dropPadMode == 0 && activateNum > 0) {
        tilingData->activateRows = (std::min(activateNum, totalLength));
    }
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        tilingData->needCoreNum = 0;  // no rows to process
        return;
    }
    tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows));
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;
    // Per-column UB footprint: 2x int8 output, float scale, int16 temp, 2x input element.
    int64_t sizeOfCol = sizeof(int8_t) * NUM_TWO + sizeof(float) + sizeof(int16_t) + inuptXDtypeSize_ * NUM_TWO;
    int64_t rowSize = AlignOneBlockByte((perCoreRows * sizeof(int32_t) * NUM_TWO));
    int64_t colSize = AlignOneBlockByte(cols * sizeOfCol);
    if (rowSize + colSize < static_cast<int64_t>(aicoreParams_.ubSize) / NUM_TWO) {
        // Fits in half of UB: single loop over rows and columns.
        SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols);
    } else {
        int64_t baseMaxCols = MAX_COLS_ONE_LOOP_QUANT;
        int64_t baseMaxColsSize = AlignOneBlockByte(baseMaxCols * sizeOfCol);
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - baseMaxColsSize) / NUM_TWO / sizeof(int32_t));
        if (cols < MAX_COLS_ONE_LOOP_QUANT) {
            // Columns fit in one loop: spend the remaining UB on more rows per loop.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize) / NUM_TWO / sizeof(int32_t));
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Rows fit in one loop: widen the column loop instead.
            baseMaxCols = AlignOneBlockByteCeil((ubSize - rowSize) / sizeOfCol);
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
// Fills the gather-stage loop tiling fields on |tilingData|.
// All columns are covered by a single column loop; the remaining arguments
// describe the row loops for the ordinary cores and the tail core.
void SetGatherTilingDatawithloop(InnerMoeV2GatherOutComputeTilingData* tilingData,
                                 int64_t perCorePerLoopRows, int64_t lastCorePerLoopRows, int64_t cols,
                                 int64_t perCoreLastLoopRows = 1, int64_t lastCoreLastLoopRows = 1,
                                 int64_t perCoreLoops = 1, int64_t lastCoreLoops = 1) {
    // Single column pass covering every column.
    tilingData->colLoops = 1;
    tilingData->perLoopCols = cols;
    tilingData->lastLoopCols = cols;
    // Row loop layout for the ordinary cores.
    tilingData->perCoreLoops = perCoreLoops;
    tilingData->perCorePerLoopRows = perCorePerLoopRows;
    tilingData->perCoreLastLoopRows = perCoreLastLoopRows;
    // Row loop layout for the tail core.
    tilingData->lastCoreLoops = lastCoreLoops;
    tilingData->lastCorePerLoopRows = lastCorePerLoopRows;
    tilingData->lastCoreLastLoopRows = lastCoreLastLoopRows;
}
// Computes gather-stage tiling for dynamic (per-row) quantization output.
// Besides the row/column split used by the static-quant path, this path also
// budgets UB space for per-row scales and supports a special full-load
// multi-loop mode (ifOneLoop) when there is no smooth scale and the hidden
// size equals FULLLOAD_H_LIMIT.
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherDynamicQuant() {
    auto tilingData = &quantTilingData.gatherOutComputeParamsOp;
    tilingData->activateRows = totalLength;
    if (dropPadMode == 0 && activateNum > 0) {
        // Dropless mode with an explicit activation cap: clamp the row count.
        tilingData->activateRows = (std::min(activateNum, totalLength));
    }
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        tilingData->needCoreNum = 0;
        return;
    }
    tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows));
    int64_t cols = InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;
    // UB byte budgets: row indices (x4 buffers), int8 column data, and scale block.
    int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR;
    int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER;
    int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64;
    // Rows that fit in one UB pass after reserving column, scale and alignment space.
    int64_t onceRowSize = (static_cast<int64_t>(aicoreParams_.ubSize) -
                           colSize - scaleSize -
                           ONE_BLOCK_BYTE * NUM_FOUR * NUM_THREE) /
                          (sizeof(int32_t) * NUM_FOUR);
    int64_t oneBlockNumInt = static_cast<int64_t>(ONE_BLOCK_BYTE) / static_cast<int64_t>(sizeof(int32_t));
    // Round the one-pass row count down to a whole 32-byte block of int32 indices.
    onceRowSize = onceRowSize / oneBlockNumInt * oneBlockNumInt;
    // Full-load fast path: no smooth scale and hidden size exactly at the
    // full-load limit, with enough UB headroom.
    bool ifOneLoop = ((static_cast<int64_t>(aicoreParams_.ubSize) > colSize +
                       scaleSize + ONE_BLOCK_BYTE * NUM_FOUR * NUM_FOUR) &&
                      quantTilingData.smoothType == SMOOTH_NONE &&
                      cols == FULLLOAD_H_LIMIT);
    int64_t perCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, perCoreRows) : perCoreRows;
    int64_t lastCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, lastCoreRows) : lastCoreRows;
    int64_t perCoreLoops = ifOneLoop ? CeilDiv(perCoreRows, perCoreOnceRowSize) : 1;
    int64_t lastCoreLoops = ifOneLoop ? CeilDiv(lastCoreRows, lastCoreOnceRowSize) : 1;
    int64_t perCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(perCoreRows, perCoreOnceRowSize) : perCoreRows;
    int64_t lastCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(lastCoreRows, lastCoreOnceRowSize) : lastCoreRows;
    if (rowSize + colSize + scaleSize < static_cast<int64_t>(aicoreParams_.ubSize) || ifOneLoop) {
        // Either everything fits in UB or the full-load loop schedule applies.
        SetGatherTilingDatawithloop(tilingData, perCoreOnceRowSize, lastCoreOnceRowSize, cols,
                                    perCoreLastLoopRows, lastCoreLastLoopRows,
                                    perCoreLoops, lastCoreLoops);
    } else {
        // Otherwise tile by a max column chunk and derive the per-loop row budget
        // from the remaining UB space.
        int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT;
        int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER;
        int64_t ubSize = static_cast<int64_t>(aicoreParams_.ubSize);
        int64_t basePerLoopMaxRows =
            AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        if (cols < MAX_COLS_DYNAMIC_QUANT) {
            // Columns fit in one chunk: size the row loop against the true column cost.
            basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Rows fit in one loop: widen the column chunk with the leftover space.
            baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_COLS_BUFFER;
        }
        SetGatherTilingDataCols(tilingData, baseMaxCols, cols);
        SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows);
    }
}
// Dispatches gather-stage tiling by quantization mode:
// 0 selects static quantization, any other value selects dynamic quantization.
void MoeInitRoutingQuantV2TilingBase::Tiling4GatherOutCompute() {
    if (quantMode != 0) {
        Tiling4GatherDynamicQuant();
        return;
    }
    Tiling4GatherQuant();
}
}

View File

@@ -0,0 +1,410 @@
#pragma once
#include "tiling_base.h"
namespace optiling {
// ---- Tiling keys selected by GetTilingKey() ----
const static int64_t TILING_KEY_DROPLESS_SORT_ONE_CORE = 10001;
const static int64_t TILING_KEY_DROPLESS_SORT_MULTI_CORE = 10002;
const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE = 10011;
const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE = 10012;
const static int64_t TILING_KEY_HIGH_PERFORMANCE = 20000;
// ---- Named small numbers used in UB-budget arithmetic ----
const static int64_t NUM_TWO = 2;
const static int64_t NUM_THREE = 3;
const static int64_t NUM_FOUR = 4;
// Merge sort consumes up to four sorted lists per pass.
const static int64_t MRG_LIST_NUM = 4;
// Hardware sort granularity: element counts are aligned to 32.
const static int64_t SORT32_ALIGN_ELEMENT = 32;
const static int64_t ONE_BLOCK_BYTE = 32;
const static size_t DIM_ONE = 1;
const static size_t DIM_TWO = 2;
const static size_t DIM_THREE = 3;
const static int32_t SIZE_16 = 16;
const static int32_t LENGTH_1024 = 1024;
const static int64_t MAX_COLS_ONE_LOOP = 16376;
const static int64_t ASSIST_NUM = 256;
// ---- Input / attribute / output indices of the op ----
const static int64_t INDEX_INPUT_X = 0;
const static int64_t INDEX_INPUT_EXPERT_IDX = 1;
const static int64_t ATTR_ACTIVE_ROWS = 0;
const static int64_t ATTR_EXPERT_CAPACITY = 1;
const static int64_t ATTR_EXPERT_NUM = 2;
const static int64_t ATTR_DROP_PAD_MODE = 3;
const static int64_t ATTR_EXPERT_TOKENS_COUNT_OR_CUMSUM_FLAG = 4;
const static int64_t ATTR_EXPERT_TOKENS_BEFORE_CAPACITY_FLAG = 5;
const static int64_t OUTOUT_EXPANDED_X = 0;
const static int64_t OUTOUT_EXPANDED_ROW_IDX = 1;
const static int64_t OUTOUT_EXPERT_TOKENS_COUNT_OR_CUMSUM = 2;
const static int64_t OUTOUT_EXPERT_TOKENS_BEFORE_CAPACITY = 3;
const static int64_t KV_FACTOR = 2;
const static int64_t ONE_CORE_SORT_BUFFER = 6;
const static int64_t EXPERT_TOKENS_COUNT = 2;

// Returns ceil(log4(x)). Defined as 0 for x <= 1 so degenerate inputs do not
// hit std::log's domain error (log(0) is -inf and log of a negative is NaN,
// whose cast to int64_t is undefined); x == 1 already returned 0 before.
inline static int64_t CeilLog4(int64_t x) {
    if (x <= 1) {
        return 0;
    }
    return static_cast<int64_t>(std::ceil(std::log(x) / std::log(NUM_FOUR)));
}

// Row/column count of the last loop when x items are processed y at a time:
// x itself when one loop suffices, otherwise the remainder x % y (which is 0
// when x divides evenly). Returns 0 when y is 0.
inline static int64_t GetPerOrLastValue(int64_t x, int64_t y) {
    if (y == 0) {
        return 0;
    }
    return x <= y ? x : x % y;
}

// Integer ceiling division; behavior is undefined for divisor == 0
// (callers guard the divisor themselves).
template <class T>
constexpr T CeilDiv(const T dividend, const T divisor)
{
    return (dividend + divisor - 1) / divisor;
}
// Tiling for the VBS (vector batch sort) stage: how the sort workload is
// split across cores and, within each core, across loops.
struct InnerMoeV2VBSComputeTilingData {
    int64_t needCoreNum = 0;              // cores participating in the sort
    int64_t perCoreElements = 0;          // elements per ordinary core
    int64_t perCoreLoops = 0;             // sort loops per ordinary core
    int64_t perCorePerLoopElements = 0;   // elements per full loop
    int64_t perCoreLastLoopElements = 0;  // elements in the final loop
    int64_t lastCoreElements = 0;         // elements on the tail core
    int64_t lastCoreLoops = 0;
    int64_t lastCorePerLoopElements = 0;
    int64_t lastCoreLastLoopElements = 0;
    int64_t oneLoopMaxElements = 0;       // UB capacity of one sort loop
};
// Tiling for the middle merge (VMS) stage; only the core count is needed.
// needCoreNum is 0 when the sort produced few enough queues that one
// four-way merge pass suffices.
struct InnerMoeV2VMSMiddleComputeTilingData {
    int64_t needCoreNum = 0;
};
// Tiling for the sort-output stage: per-loop element cap of the final merge.
struct InnerMoeV2SortOutComputeTilingData {
    int64_t oneLoopMaxElements = 0;
};
// Row/column loop tiling shared by the src-to-dst and gather-out stages.
struct InnerMoeV2GatherOutComputeTilingData {
    int64_t needCoreNum = 0;            // cores participating in this stage
    int64_t activateRows = 0;           // total rows actually activated
    int64_t perCoreRows = 0;            // rows per ordinary core
    int64_t perCorePerLoopRows = 0;     // rows per full loop (ordinary core)
    int64_t perCoreLastLoopRows = 0;    // rows in the final loop (ordinary core)
    int64_t lastCoreRows = 0;           // rows on the tail core
    int64_t lastCorePerLoopRows = 0;
    int64_t lastCoreLastLoopRows = 0;
    int64_t perCoreLoops = 0;           // row loops per ordinary core
    int64_t lastCoreLoops = 0;          // row loops on the tail core
    int64_t perLoopCols = 0;            // columns per full column loop
    int64_t lastLoopCols = 0;           // columns in the final column loop
    int64_t colLoops = 0;               // number of column loops
};
// Top-level tiling payload for the MoE init-routing kernel: problem shape,
// mode attributes, and per-stage tiling for sort, merge, src-to-dst mapping
// and gather-out.
struct InnerMoeInitRoutingV2TilingData {
    int64_t coreNum;                           // physical AIV core count
    int64_t n;                                 // number of tokens
    int64_t cols;                              // hidden size per token
    int64_t k;                                 // experts selected per token (top-k)
    int64_t expertCapacity;                    // per-expert capacity (drop/pad mode)
    int64_t expertNum;                         // total number of experts
    int64_t dropPadMode;                       // 0 = dropless, 1 = drop/pad
    int64_t expertTokensCountOrCumsumFlag;     // 0 = none, 1 = cumsum, 2 = count
    int64_t expertTokensBeforeCapacityFlag;    // emit pre-capacity token counts
    InnerMoeV2VBSComputeTilingData vbsComputeParamsOp;
    InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp;
    InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp;
    InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp;
};
// Base tiling class for the MoE init-routing V2 kernel. Subclasses supply the
// gather-out tiling (quant vs. non-quant flavors) and the full-load decision;
// this class computes the common sort/merge/mapping tiling from the shape and
// platform information.
class InnerMoeInitRoutingV2TilingBase : public TilingBaseClass {
 protected:
  bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) override;
  bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
      int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
      bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override;
  bool DoOpTiling() override;
  uint64_t GetTilingKey() const override;
  bool GetWorkspaceSize() override;

 protected:
  bool CheckTokenCount(int64_t num, const char* tag);
  // Gather-stage tiling; quant/non-quant subclasses implement it differently.
  virtual void Tiling4GatherOutCompute() = 0;
  void Tiling4SrcToDstCompute();
  virtual void Tiling4SrcToDstCapacityCompute();
  void Tiling4SortOutCompute();
  void Tiling4VMSMiddleCompute();
  void Tiling4VBSCompute();
  void ShowTilingData();
  void Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData);
  void Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData);
  // Whether the whole workload can be handled by the full-load fast path.
  virtual bool IsFullLoad() = 0;
  int64_t aivNum = 0;                         // physical AIV core count
  int64_t sortLoopMaxElement = 0;             // elements one sort loop holds in UB
  int64_t mrgSortListMaxElement = 2040;       // merge-sort per-list element cap
  int64_t totalLength = 0;                    // n * k (token, expert) pairs
  int64_t activateNum = 0;                    // optional activated-row cap
  int64_t expertCapacity = 0;
  int64_t expertNum = 0;
  int64_t dropPadMode = 0;                    // 0 = dropless, 1 = drop/pad
  int64_t expertTokensCountOrCumsumFlag = 0;
  bool expertTokensBeforeCapacityFlag = false;
  int64_t inuptXDtypeSize_ = 0;               // byte size of the input x dtype
  bool isFullLoad = false;                    // cached IsFullLoad() result
  InnerMoeInitRoutingV2TilingData moeInitRoutingTilingData;
};
// Runs every tiling sub-stage in pipeline order and caches whether the
// full-load (high performance) path applies. Always succeeds.
bool InnerMoeInitRoutingV2TilingBase::DoOpTiling() {
    // Largest element count a single sort loop can hold in UB, rounded down
    // to a whole multiple of the 32-element sort granularity.
    const auto ubSortElems = (aicoreParams_.ubSize) / (sizeof(int32_t) * NUM_TWO * NUM_FOUR);
    sortLoopMaxElement = ubSortElems / SORT32_ALIGN_ELEMENT * SORT32_ALIGN_ELEMENT;
    isFullLoad = IsFullLoad();
    Tiling4VBSCompute();
    Tiling4VMSMiddleCompute();
    Tiling4SortOutCompute();
    Tiling4SrcToDstCompute();
    Tiling4SrcToDstCapacityCompute();
    Tiling4GatherOutCompute();
    return true;
}
// Selects the kernel tiling key from the mode decided in DoOpTiling().
// The original ended with an unreachable `return tilingKey_;` after an
// exhaustive if/else; that dead code is removed here.
uint64_t InnerMoeInitRoutingV2TilingBase::GetTilingKey() const {
    if (isFullLoad) {
        // Whole workload fits in UB: high-performance single-pass kernel.
        return TILING_KEY_HIGH_PERFORMANCE;
    }
    // The sort needs only one core when everything fits in a single sort loop.
    const bool singleCoreSort = totalLength <= sortLoopMaxElement;
    if (dropPadMode == 0) {
        return singleCoreSort ? TILING_KEY_DROPLESS_SORT_ONE_CORE
                              : TILING_KEY_DROPLESS_SORT_MULTI_CORE;
    }
    return singleCoreSort ? TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE
                          : TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE;
}
// Records the operator's shape and attribute inputs into member state and the
// kernel tiling struct, and derives totalLength = n * k. Always returns true.
//
// m/cols: token count and hidden size; topK: experts chosen per token;
// expertCapacity/expertNum: drop-pad capacity and expert count; activateNum:
// optional cap on activated rows (dropless mode); dropPadMode: 0 = dropless,
// 1 = drop/pad; inuptXDtypeSize: byte size of the input x dtype; quantMode
// and scaleDim0 are accepted but unused here.
bool InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
    int64_t expertNum, int64_t activateNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
    bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) {
    this->activateNum = activateNum;
    this->expertCapacity = expertCapacity;
    this->expertNum = expertNum;
    this->dropPadMode = dropPadMode;
    this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
    this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
    // NOTE(review): the normalization below mutates the *parameters* after the
    // this-> members were already assigned above, so the members keep the
    // caller-provided values while moeInitRoutingTilingData receives the
    // normalized ones — confirm this asymmetry is intended.
    if (dropPadMode == 1) {
        // Drop/pad mode does not output expertTokensCountOrCumsum.
        expertTokensCountOrCumsumFlag = 0;
    } else {
        // Dropless mode does not output expertTokensBeforeCapacity.
        expertTokensBeforeCapacityFlag = false;
    }
    moeInitRoutingTilingData.cols = cols;
    moeInitRoutingTilingData.n = m;
    moeInitRoutingTilingData.k = topK;
    moeInitRoutingTilingData.expertCapacity = expertCapacity;
    moeInitRoutingTilingData.expertNum = expertNum;
    moeInitRoutingTilingData.dropPadMode = dropPadMode;
    moeInitRoutingTilingData.expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
    moeInitRoutingTilingData.expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
    totalLength = moeInitRoutingTilingData.n * moeInitRoutingTilingData.k;
    inuptXDtypeSize_ = inuptXDtypeSize;
    return true;
}
// Caches platform capabilities (AIV core count and UB size) for the later
// tiling steps. Always succeeds.
bool InnerMoeInitRoutingV2TilingBase::GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) {
    aicoreParams_.blockDim = aivCoreNum;
    aicoreParams_.ubSize = ubSizePlatForm;
    aivNum = aivCoreNum;
    moeInitRoutingTilingData.coreNum = aivCoreNum;
    return true;
}
// Computes the device workspace requirement: sort scratch + scatter index
// scratch + per-core token flags + a fixed 16 MiB system reserve.
bool InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize() {
    const size_t sortBytes = totalLength * sizeof(float) * NUM_TWO * NUM_THREE;  // sort stage scratch
    const size_t scatterBytes = totalLength * sizeof(int32_t) * NUM_TWO;         // scatter index scratch
    const size_t tokenFlagBytes = aivNum * 2 * sizeof(int32_t);                  // per-core expert-token flags
    const size_t reservedBytes = SIZE_16 * LENGTH_1024 * LENGTH_1024;            // fixed 16 MiB reserve
    workspaceSize_ = sortBytes + scatterBytes + tokenFlagBytes + reservedBytes;
    return true;
}
// Single-core sort tiling: one core handles the entire workload in one loop,
// so every element field carries the full length.
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
    const int64_t allElements = totalLength;
    tilingData->needCoreNum = 1;
    tilingData->perCoreLoops = 1;
    tilingData->lastCoreLoops = 1;
    tilingData->perCoreElements = allElements;
    tilingData->perCorePerLoopElements = allElements;
    tilingData->perCoreLastLoopElements = allElements;
    tilingData->lastCoreElements = allElements;
    tilingData->lastCorePerLoopElements = allElements;
    tilingData->lastCoreLastLoopElements = allElements;
}
// Multi-core sort tiling. The queue count is rounded up to a power of four
// (the merge stage consumes four lists per pass) and capped at the physical
// core count. Per-core shares are aligned to the 32-element sort granularity;
// the do/while shrinks the share until the tail core's final loop keeps a
// positive amount of work. The unused local `tmp` (a duplicate of the inner
// CeilDiv) from the original has been removed.
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
    int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement);       // round up
    needCoreNum = static_cast<int64_t>(std::pow(4, CeilLog4(needCoreNum)));
    needCoreNum = std::min(needCoreNum, aivNum);                          // cap at physical cores
    if (needCoreNum <= 0) {
        return;
    }
    int64_t perCoreElements = totalLength / needCoreNum;                  // elements per core
    int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT;
    int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements;
    int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT;
    if (lastCoreElement > alineCeilPerCoreElements) {
        // The tail core would be overloaded: round the share up instead and recount.
        perCoreElements = alineCeilPerCoreElements;
        needCoreNum = CeilDiv(totalLength, perCoreElements);
    } else {
        perCoreElements = alineFloorPerCoreElements;
    }
    tilingData->needCoreNum = needCoreNum;
    do {
        tilingData->perCoreElements = perCoreElements;
        tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement);  // loops per core
        tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement);
        tilingData->perCoreLastLoopElements =
            tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements;
        tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements;
        tilingData->lastCoreLoops = tilingData->perCoreLoops;
        // Tail core's per-loop share, rounded up to the sort granularity.
        int64_t lastCorePerLoopElements =
            CeilDiv(CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops), SORT32_ALIGN_ELEMENT) *
            SORT32_ALIGN_ELEMENT;
        tilingData->lastCorePerLoopElements = lastCorePerLoopElements;
        tilingData->lastCoreLastLoopElements =
            tilingData->lastCoreElements - (tilingData->lastCoreLoops - 1) * tilingData->lastCorePerLoopElements;
        // If the tail core's last loop went non-positive, shrink the share and retry.
        perCoreElements -= SORT32_ALIGN_ELEMENT;
    } while (tilingData->lastCoreLastLoopElements <= 0 && perCoreElements > 0);
}
// Chooses between single-core and multi-core sort tiling based on whether the
// whole workload fits into one sort loop.
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() {
    auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
    tilingData->oneLoopMaxElements = sortLoopMaxElement;
    if (totalLength > sortLoopMaxElement) {
        Tiling4VBSMultiCoreCompute(tilingData);
    } else {
        // Everything fits in one loop: a single core suffices.
        Tiling4VBSOneCoreCompute(tilingData);
    }
}
// Middle merge (VMS) tiling: a middle merge pass is only needed when the sort
// produced more queues than a single four-way merge can consume.
void InnerMoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() {
    auto sortTiling = &moeInitRoutingTilingData.vbsComputeParamsOp;
    auto tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp;
    tilingData->needCoreNum = sortTiling->needCoreNum <= MRG_LIST_NUM
                                  ? 0  // few enough queues: no middle merge
                                  : CeilDiv(sortTiling->needCoreNum, MRG_LIST_NUM);
}
// Sort-output tiling: records the per-loop element cap for the final merge.
void InnerMoeInitRoutingV2TilingBase::Tiling4SortOutCompute() {
    moeInitRoutingTilingData.sortOutComputeParamsOp.oneLoopMaxElements = mrgSortListMaxElement;
}
// Computes tiling for the src-to-dst index mapping stage: rows are split
// evenly across AIV cores, and each core's rows are split into loops bounded
// by the UB budget (after reserving space for the assist table and per-core
// alignment).
void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCompute() {
    auto tilingData = &moeInitRoutingTilingData.srcToDstComputeParamsOp;
    // Max rows one loop can hold in the remaining UB space.
    int64_t perLoopMaxRows = (aicoreParams_.ubSize - ASSIST_NUM * sizeof(float) - aivNum * SORT32_ALIGN_ELEMENT) /
                             (SORT32_ALIGN_ELEMENT * NUM_TWO) / NUM_TWO;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        tilingData->needCoreNum = 0;
        return;
    }
    int64_t needCoreNum = CeilDiv(totalLength, perCoreRows);
    tilingData->needCoreNum = needCoreNum;
    // The tail core takes whatever rows remain after the full cores.
    int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
    tilingData->perCoreRows = perCoreRows;
    if (perLoopMaxRows >= tilingData->perCoreRows) {  // finishes in a single loop
        tilingData->perCorePerLoopRows = tilingData->perCoreRows;
        tilingData->perCoreLastLoopRows = tilingData->perCoreRows;
    } else {
        tilingData->perCorePerLoopRows = perLoopMaxRows;
        tilingData->perCoreLastLoopRows = tilingData->perCoreRows - (CeilDiv(tilingData->perCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows;
    }
    tilingData->lastCoreRows = lastCoreNum;
    if (perLoopMaxRows >= tilingData->lastCoreRows) {
        // Tail core finishes in a single loop too.
        tilingData->lastCorePerLoopRows = tilingData->lastCoreRows;
        tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows;
    } else {
        tilingData->lastCorePerLoopRows = perLoopMaxRows;
        tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows - (CeilDiv(tilingData->lastCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows;
    }
}
// Computes tiling for the src-to-dst mapping stage in drop/pad (capacity)
// mode, which also moves row data of width `cols`. If a core's index rows
// plus one full row of columns fit in UB, a single pass is used; otherwise
// both rows and columns are looped, with the chunk sizes derived from the
// remaining UB budget.
void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute() {
    auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp;
    int64_t perCoreRows = CeilDiv(totalLength, aivNum);
    if (perCoreRows <= 0) {
        tilingData->needCoreNum = 0;
        return;
    }
    int64_t needCoreNum = CeilDiv(totalLength, perCoreRows);
    tilingData->needCoreNum = needCoreNum;
    int64_t cols = moeInitRoutingTilingData.cols;
    tilingData->perCoreRows = perCoreRows;
    int64_t lastCoreRows = totalLength - perCoreRows * (needCoreNum - 1);
    tilingData->lastCoreRows = lastCoreRows;
    // 32-byte-aligned UB cost of the per-core index rows and of one data row.
    int64_t rowSize =
        (perCoreRows * sizeof(int32_t) * 2 + ONE_BLOCK_BYTE + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
    int64_t colSize = (cols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
    if (rowSize + colSize < static_cast<int64_t>(aicoreParams_.ubSize)) {
        // Whole per-core workload fits in UB: one row loop, one column loop.
        tilingData->perCorePerLoopRows = perCoreRows;
        tilingData->perCoreLastLoopRows = perCoreRows;
        tilingData->lastCorePerLoopRows = lastCoreRows;
        tilingData->lastCoreLastLoopRows = lastCoreRows;
        tilingData->perCoreLoops = 1;
        tilingData->lastCoreLoops = 1;
        tilingData->perLoopCols = cols;
        tilingData->lastLoopCols = cols;
        tilingData->colLoops = 1;
    } else {
        // Start from the max column chunk, then size the per-loop row budget
        // against the leftover UB space.
        int64_t baseMaxCols = MAX_COLS_ONE_LOOP;
        int64_t baseMaxColsSize = (baseMaxCols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        int64_t basePerLoopMaxRows = (static_cast<int64_t>(aicoreParams_.ubSize) - baseMaxColsSize - ONE_BLOCK_BYTE) /
                                     static_cast<int64_t>(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        if (cols < MAX_COLS_ONE_LOOP) {
            // Columns fit in one chunk: size the row loop against the true column cost.
            basePerLoopMaxRows = (static_cast<int64_t>(aicoreParams_.ubSize) - colSize - ONE_BLOCK_BYTE) /
                                 static_cast<int64_t>(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        } else if (perCoreRows < basePerLoopMaxRows) {
            // Rows fit in one loop: widen the column chunk with the leftover space.
            baseMaxCols =
                (static_cast<int64_t>(aicoreParams_.ubSize) - rowSize) / inuptXDtypeSize_ / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE;
        }
        tilingData->perLoopCols = (std::min(baseMaxCols, cols));
        tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols));
        tilingData->colLoops = ((cols + baseMaxCols - 1) / baseMaxCols);
        tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows));
        tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows));
        tilingData->perCoreLoops = ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
        tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows));
        tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows));
        tilingData->lastCoreLoops = ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows);
    }
}
}

View File

@@ -0,0 +1,94 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_common.h
* \brief
*/
#ifndef INNER_MOE_V2_COMMON_H
#define INNER_MOE_V2_COMMON_H
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// ---- Device-side constants shared by the MoeInitRoutingQuantV2 kernel ----
constexpr int64_t SPLIT_N = 0;
constexpr int64_t SPLIT_K = 1;
constexpr float MIN_FP32 = -3.4e38;          // near the most negative finite fp32
constexpr int64_t ONE_REPEAT_SORT_NUM = 32;  // elements sorted per hardware repeat
constexpr int64_t BLOCK_BYTES = 32;          // UB block size in bytes
constexpr int64_t INT32_ONE_BLOCK_NUM = 8;   // int32 elements per 32-byte block
constexpr int64_t ASSIST_NUM = 256;          // length of the assist table below
constexpr int64_t ASSIST_INDEX_NUM = 32;
constexpr int64_t MERGE_LIST_TWO = 2;
constexpr int64_t MERGE_LIST_THREE = 3;
constexpr int64_t MERGE_LIST_FOUR = 4;
constexpr int64_t MERGE_LIST_IDX_TWO = 2;
constexpr int64_t MERGE_LIST_IDX_THREE = 3;
constexpr int64_t MAX_EXPERT_NUM = 5120;
// Drop modes (match the host-side dropPadMode attribute).
constexpr int64_t DROPLESS_MODE = 0;
constexpr int64_t DROP_PAD_MODE = 1;
// expertTokensCountOrCumsumFlag / expertTokensBeforeCapacityFlag values.
constexpr int64_t EXERPT_TOKENS_COUNT = 2;
constexpr int64_t EXERPT_TOKENS_CUMSUM = 1;
constexpr int64_t EXERPT_TOKENS_NONE = 0;
constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1;
// Assist table in global memory: the values 0..31, one per 8-int32 (32-byte)
// block, with the remaining slots zeroed.
const __gm__ int32_t assist[256] = {
    0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,
    4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0,
    8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0,
    12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0,
    16, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0,
    20, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0,
    24, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0, 0,
    28, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0};
// Integer ceiling division; returns 0 when the divisor is 0.
__aicore__ inline int64_t Ceil(int64_t a, int64_t b) {
    if (b != 0) {
        return (a + b - 1) / b;
    }
    return 0;
}
// Rounds elementNum up so that elementNum * bytes covers a whole number of
// 32-byte blocks, returning the aligned element count (0 when bytes is 0).
__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) {
    if (bytes == 0) {
        return 0;
    }
    int64_t alignedBytes = (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
    return alignedBytes / bytes;
}
// Rounds elementNum * bytes up to the next multiple of the 32-byte block size.
__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) {
    return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES;
}
// Returns the smaller of the two values (a is returned on ties).
template <typename T>
__aicore__ inline T Min(T a, T b) {
    return (b < a) ? b : a;
}
// Returns the larger of the two values (a is returned on ties).
template <typename T>
__aicore__ inline T Max(T a, T b) {
    return (b > a) ? b : a;
}
// Issues a matched SetFlag/WaitFlag pair on a freshly fetched event ID for
// |evt|, i.e. a same-core pipeline barrier between the two hardware queues
// named by the event. Callers in this file pass the same event as both the
// template argument and the runtime argument (e.g. HardEvent::V_MTE3).
template <HardEvent event>
__aicore__ inline void SetWaitFlag(HardEvent evt) {
    event_t eventId = static_cast<event_t>(GetTPipePtr()->FetchEventID(evt));
    SetFlag<event>(eventId);
    WaitFlag<event>(eventId);
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_COMMON_H

View File

@@ -0,0 +1,310 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_expert_token_out.h
* \brief
*/
#ifndef INNER_MOE_V2_EXPERT_TOKEN_OUT_H
#define INNER_MOE_V2_EXPERT_TOKEN_OUT_H
#include "moe_v2_common.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Two int32 values are staged per core: [lastExpertId, tokenCount]
// (see MoeV2ExpertTokenOut::CopyOutTokenGm).
constexpr int64_t EXPERT_ID_VALUE_NUM = 2;
// Device-side stage that scans the sorted expanded expert indices and emits
// per-expert token counts (or their cumulative sum), optionally the
// pre-capacity counts in drop/pad mode, and initializes expandedRowIdx.
class MoeV2ExpertTokenOut {
 public:
  __aicore__ inline MoeV2ExpertTokenOut(){};
  // Binds GM buffers and per-core tiling; must be called before Process().
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                              GM_ADDR expandedRowIdx, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();

 private:
  __aicore__ inline void CopyIn(int64_t progress);
  __aicore__ inline void Compute(int64_t progress);
  __aicore__ inline void SyncAll();
  __aicore__ inline void InitLocal();
  __aicore__ inline void GetExpertTokenCount(int32_t curExpertId);
  __aicore__ inline void CopyOutTokenGm();
  __aicore__ inline void CopyOutExpertTokensCumsum(bool isTail);
  __aicore__ inline void CopyOutExpertTokensCount(bool isTail);

 private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;                // staged expanded expert ids
  TQue<QuePosition::VECIN, 1> expertTokenIdxCopyInQueue;
  TQue<QuePosition::VECOUT, 1> expertTokenIdxCopyOutQueue;  // per-expert token accumulator
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;      // output: counts or cumsum (dropless)
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;     // output: pre-capacity counts (drop/pad)
  GlobalTensor<int32_t> expandedExpertIdxGm;              // input: sorted expert ids (workspace)
  GlobalTensor<int32_t> expertIdxValueGm;                 // workspace: per-core [lastExpertId, tokenCount]
  GlobalTensor<int32_t> expandedRowIdxGm;                 // output: row index mapping
  LocalTensor<int32_t> expertTokenIdxOutLocal;            // UB window of per-expert counts
  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;
  int64_t coreNum;
  int64_t blockIdx;
  int64_t totalLength;          // n * k
  int64_t currentLoopRows;      // rows handled by the loop in flight
  int64_t coreRows;             // rows assigned to this core
  int64_t perLoopRows;
  int64_t lastLoopRows;
  int64_t expertNum;
  int64_t expertNumUbAlign;     // experts that fit in the UB accumulator window
  int64_t dropPadMode = 0;
  int64_t expertTokensCountOrCumsumFlag = 0;
  int64_t expertTokensBeforeCapacityFlag = 0;
  int64_t tokenCount = 0;       // running count for the current expert
  int64_t expertIdx = 0;        // slot of the current expert inside the UB window
  int32_t lastExpertId = -1;    // most recent expert id seen (-1 = none yet)
  int32_t firstExpertId = -1;   // expert id at the start of the UB window
  int32_t expertTokenValue = 0; // running cumulative sum for the cumsum output
};
// Zeroes the per-expert token accumulator in UB and, in drop/pad mode,
// pre-fills this core's slice of expandedRowIdx with -1 so the later
// src_to_dst_with_capacity step can detect unwritten slots.
__aicore__ inline void MoeV2ExpertTokenOut::InitLocal() {
    LocalTensor<int32_t> tokenIdxLocal = expertTokenIdxCopyOutQueue.AllocTensor<int32_t>();
    Duplicate<int32_t>(tokenIdxLocal, 0, this->expertNumUbAlign);
    expertTokenIdxCopyOutQueue.EnQue<int32_t>(tokenIdxLocal);
    // expandedRowIdx initialized to -1, which is used in the src_to_dst_with_capacity step.
    // use this step SyncAll to synchronize every core data
    if (this->dropPadMode == 0) {
        return;
    }
    LocalTensor<int32_t> outLocal = copyInQueue.AllocTensor<int32_t>();
    int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows;
    Duplicate<int32_t>(outLocal, -1, perLoopRows);
    // Vector write must complete before the MTE3 copies below read outLocal.
    SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    for (int64_t loop = 0; loop < loops; loop++) {
        int64_t copyLength = perLoopRows;
        if (loop == loops - 1) {
            copyLength = lastLoopRows;  // tail loop may be shorter
        }
        DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0,
                                     0};
        DataCopyPad(expandedRowIdxGm[this->blockIdx * this->srcToDstTilingData->perCoreRows + loop * perLoopRows], outLocal,
                    copyParams);
    }
    // Ensure all -1 writes land before the buffer is reused for MTE2 input.
    SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
    copyInQueue.FreeTensor(outLocal);
}
// Loads the next chunk of sorted expanded expert ids from workspace into UB.
__aicore__ inline void MoeV2ExpertTokenOut::CopyIn(int64_t progress) {
    LocalTensor<int32_t> inLocal = copyInQueue.AllocTensor<int32_t>();
    DataCopy(inLocal, expandedExpertIdxGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t)));
    copyInQueue.EnQue<int32_t>(inLocal);
}
// Accumulates one sorted expert id. The ids arrive in non-decreasing order;
// when the id advances, the finished expert's count is committed to the UB
// window, and if the new id falls outside the window the window is flushed
// to GM and re-zeroed until it covers the new id.
__aicore__ inline void MoeV2ExpertTokenOut::GetExpertTokenCount(int32_t curExpertId) {
    this->tokenCount++;
    if (this->lastExpertId < curExpertId) {
        // The previous expert is complete: store its count (tokenCount was
        // already incremented for the current id, hence the -1).
        this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount - 1);
        this->tokenCount = 1;
        this->expertIdx += (curExpertId - this->lastExpertId);
        // Slide the UB window forward until curExpertId fits inside it.
        while (curExpertId - this->firstExpertId + 1 > this->expertNumUbAlign) {
            SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
            CopyOutExpertTokensCumsum(false);
            CopyOutExpertTokensCount(false);
            SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
            Duplicate<int32_t>(this->expertTokenIdxOutLocal, 0, this->expertNumUbAlign);
            SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
            this->firstExpertId += this->expertNumUbAlign;
            this->expertIdx = curExpertId - this->firstExpertId;
        }
        this->lastExpertId = curExpertId;
    }
}
// Consumes one staged chunk of expert ids, feeding each id through
// GetExpertTokenCount and recording the in-progress count for the current
// expert at the end of the chunk.
__aicore__ inline void MoeV2ExpertTokenOut::Compute(int64_t progress) {
    LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
    // Scalar reads below must wait for the MTE2 copy to finish.
    SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        // First chunk on this core: anchor the UB window at the first id seen.
        this->lastExpertId = inLocal.GetValue(0);
        this->firstExpertId = this->lastExpertId;
    }
    for (int64_t i = 0; i < currentLoopRows; i++) {
        int32_t expertId = inLocal.GetValue(i);
        GetExpertTokenCount(expertId);
    }
    // Keep the running count of the still-open expert visible in the window.
    this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount);
    copyInQueue.FreeTensor(inLocal);
}
// Converts the UB window of per-expert counts into a running cumulative sum
// and atomically adds it into expertTokensCountOrCumsumGm (dropless + cumsum
// mode only). On the tail call, experts past the last id seen are filled with
// the final cumulative value, either inside the current window or through
// extra flush loops when they do not fit.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCumsum(bool isTail) {
    if (this->dropPadMode != DROPLESS_MODE || expertTokensCountOrCumsumFlag != EXERPT_TOKENS_CUMSUM) {
        return;
    }
#ifdef __CCE_KT_TEST__
    // CPU-twin debugging cannot use multi-core sync, so the index may hold
    // uninitialized garbage; bail out to avoid reading it.
    if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) {
        return;
    }
#endif
    int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
    int64_t end = this->expertNum - this->firstExpertId;
    // In-place prefix sum over the window, carrying expertTokenValue across windows.
    for (int64_t i = 0; i < copyLength; i++) {
        this->expertTokenValue += this->expertTokenIdxOutLocal.GetValue(i);
        this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue);
    }
    // if the remianing UB is sufficient, use the UB space to copy
    // otherwise, copy the calculated data first, and then copy the last tokenValue to remaining expert position
    if (isTail && end <= this->expertNumUbAlign) {
        int64_t startAlign = Min(Align(copyLength, sizeof(int32_t)), end);
        // Scalar-fill up to the block-aligned boundary, vector-fill the rest.
        for (int64_t i = copyLength; i < startAlign; i++) {
            this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue);
        }
        if (startAlign < end) {
            Duplicate<int32_t>(this->expertTokenIdxOutLocal[startAlign], this->expertTokenValue, end - startAlign);
        }
        copyLength = end;
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
    }
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
    // Atomic add: every core contributes its partial sums to the same output.
    SetAtomicAdd<int32_t>();
#ifndef __CCE_KT_TEST__
    DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
#endif
    SetAtomicNone();
    if (isTail && end > this->expertNumUbAlign) {
        // Remaining experts do not fit in the window: flush the final value in
        // expertNumUbAlign-sized batches.
        int64_t remainderLength = end - copyLength;
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
        Duplicate<int32_t>(this->expertTokenIdxOutLocal, this->expertTokenValue, this->expertNumUbAlign);
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        int64_t loopTimes = remainderLength / this->expertNumUbAlign + 1;
        for (int64_t i = 0; i < loopTimes; i++) {
            copyLength = i == loopTimes - 1 ? remainderLength - this->expertNumUbAlign * i : this->expertNumUbAlign;
            DataCopyExtParams params{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
            SetAtomicAdd<int32_t>();
            DataCopyPad(expertTokensCountOrCumsumGm[this->lastExpertId + 1 + this->expertNumUbAlign * i],
                        this->expertTokenIdxOutLocal, params);
            SetAtomicNone();
        }
    }
}
// Atomically adds the UB window of per-expert token counts into the matching
// GM output: pre-capacity counts in drop/pad mode, plain counts in dropless
// count mode. Skipped entirely in the CPU-twin debug build.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCount(bool isTail) {
    int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
#ifdef __CCE_KT_TEST__
    // CPU-twin debugging: do not perform the output copy.
    return;
#endif
    SetAtomicAdd<int32_t>();
    if (this->dropPadMode == DROP_PAD_MODE && expertTokensBeforeCapacityFlag > EXERPT_TOKENS_NONE) {
        DataCopyPad(expertTokensBeforeCapacityGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
    }
    if (this->dropPadMode == DROPLESS_MODE && expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
        DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams);
    }
    SetAtomicNone();
}
// Final flush after all chunks are consumed. Dropless mode emits the tail
// cumsum/count windows; drop/pad mode additionally stages this core's
// [lastExpertId, tokenCount] pair into the per-core workspace slot.
__aicore__ inline void MoeV2ExpertTokenOut::CopyOutTokenGm() {
    if (this->dropPadMode == DROPLESS_MODE) {
        SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
        CopyOutExpertTokensCumsum(true);
        CopyOutExpertTokensCount(true);
        return;
    }
    // Stash the open expert's id and running count just past the window so a
    // single padded copy can move both values.
    this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign, this->lastExpertId);
    this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign + 1, this->tokenCount);
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(EXPERT_ID_VALUE_NUM * sizeof(int32_t)),
                                 0, 0, 0};
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    DataCopyPad(expertIdxValueGm[this->blockIdx * EXPERT_ID_VALUE_NUM],
                this->expertTokenIdxOutLocal[this->expertNumUbAlign], copyParams);
    CopyOutExpertTokensCount(true);
}
// Cross-core barrier; a no-op when only one core runs, and skipped in the
// CPU-twin debug build where multi-core sync is unavailable.
__aicore__ inline void MoeV2ExpertTokenOut::SyncAll() {
    if (coreNum == 1) {
        return;
    }
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
}
// Per-core initialization: caches tiling parameters, binds the GM tensors this
// stage reads/writes, and carves out the UB queues used by Process().
//
// expertTokensCountOrCumsum : optional per-expert token count/cumsum output
// expertTokensBeforeCapacity: optional pre-capacity per-expert count (drop/pad mode)
// expandedRowIdx            : src->dst routing map produced by earlier stages
// workspace                 : scratch GM holding the sorted expert ids plus the
//                             per-core (lastExpertId, tokenCount) pairs
template <typename TilingData>
__aicore__ inline void MoeV2ExpertTokenOut::Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                                                 GM_ADDR expandedRowIdx, GM_ADDR workspace,
                                                 const TilingData* tilingData, TPipe* tPipe) {
    this->pipe = tPipe;
    // Flatten (cube block idx, vector sub-block idx) into one linear block index.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->coreNum = tilingData->coreNum;
    this->totalLength = tilingData->n * tilingData->k;
    this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp);
    this->expertNum = tilingData->expertNum;
    this->dropPadMode = tilingData->dropPadMode;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
    this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;
    // The last active core may carry a shorter tail of rows.
    if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
        this->coreRows = this->srcToDstTilingData->lastCoreRows;
        this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
    } else {
        this->coreRows = this->srcToDstTilingData->perCoreRows;
        this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
    }
    expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, Align(this->totalLength, sizeof(int32_t)));
    if (this->dropPadMode == DROPLESS_MODE && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
        expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, this->expertNum);
    }
    if (this->dropPadMode == DROP_PAD_MODE && this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
        expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity, this->expertNum);
    }
    // This core's slice of the sorted expert ids lives at a per-core offset in workspace.
    expandedExpertIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));
    // Per-core (lastExpertId, tokenCount) pairs follow the two aligned index arrays.
    expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2,
                                     this->coreNum * 2);
    this->expertNumUbAlign = Min(Align(this->expertNum, sizeof(int32_t)), MAX_EXPERT_NUM);
    pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES);
    pipe->InitBuffer(expertTokenIdxCopyInQueue, 1, this->expertNumUbAlign * sizeof(int32_t));
    pipe->InitBuffer(expertTokenIdxCopyOutQueue, 1, (this->expertNumUbAlign + EXPERT_ID_VALUE_NUM) * sizeof(int32_t));
}
// Main loop for this stage: stream the core's rows tile-by-tile through
// CopyIn/Compute, then flush the accumulated token statistics to GM.
__aicore__ inline void MoeV2ExpertTokenOut::Process() {
    if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
        // Ceiling division: number of tiles this core must process.
        const int64_t totalLoops = (coreRows + perLoopRows - 1) / perLoopRows;
        currentLoopRows = perLoopRows;
        InitLocal();
        this->expertTokenIdxOutLocal = expertTokenIdxCopyOutQueue.DeQue<int32_t>();
        // All full tiles first; the final (possibly short) tile is handled below.
        for (int64_t tile = 0; tile + 1 < totalLoops; ++tile) {
            CopyIn(tile);
            Compute(tile);
        }
        currentLoopRows = lastLoopRows;
        CopyIn(totalLoops - 1);
        Compute(totalLoops - 1);
        CopyOutTokenGm();
        expertTokenIdxCopyOutQueue.FreeTensor(this->expertTokenIdxOutLocal);
    }
    // Every core (active or not) joins the barrier so later stages see consistent GM.
    this->SyncAll();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_EXPERT_TOKEN_OUT_H

View File

@@ -0,0 +1,468 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_v2_fullload_dynamic_quant.h
* \brief
*/
#ifndef MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H
#define MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H
#include "moe_v2_mrgsort.h"
#include "moe_v2_sort_base.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Full-load (single-pass) MoE routing kernel with per-row dynamic int8
// quantization: the entire expertIdx tensor fits in UB, is sorted once, and
// each routed row of x is scaled by max(|row|) / 127 before being written out
// together with its dynamic scale. T is the activation dtype of x.
template <typename T>
class MoeV2FullLoadDynamicQuant : public MoeV2SortBase {
 public:
  __aicore__ inline MoeV2FullLoadDynamicQuant(){};
  __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                              GM_ADDR expertTokensCountOrCumsum, GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale,
                              GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void SortCompute();
  __aicore__ inline void CopyOutIdx();
  __aicore__ inline void CopyOutEmpty();
  __aicore__ inline void CopyOutXQuant1H();   // shared (or no) smooth vector path
  __aicore__ inline void CopyOutXQuantEH();   // per-expert smooth vector path
  __aicore__ inline void ComputeExpertTokenCountOrCumsum();
  __aicore__ inline void Compute(LocalTensor<float>& smoothLocal);
 private:
  int64_t sortNum_;                 // totalLength rounded up to a full sort repeat
  const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_;
  int64_t blockIdx_;                // linear core index
  int64_t needCoreNum_;             // cores actually used by the tiling plan
  int64_t coreRows_;                // rows handled by this core
  int64_t perCoreRows_;             // rows per non-tail core
  int64_t k_;                       // top-k experts per token
  int64_t n_;                       // number of tokens
  int64_t cols_;                    // hidden size of x
  int64_t activateRows_;            // rows actually materialized (dropless cap)
  int64_t expertNum;
  int64_t expertCapacity;
  int64_t smoothType;               // 0: none, 1: single vector, 2: per-expert
  int64_t colsAlign;                // cols_ aligned for T
  TQue<QuePosition::VECIN, 1> xCopyInQueue_;
  TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue_;
  TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue_;
  TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue_;
  TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue_;
  TQue<QuePosition::VECIN, 1> smoothInQueue;
  TQue<QuePosition::VECOUT, 1> calcQueue;
  TQue<QuePosition::VECOUT, 1> inputXOutQueue;
  TQue<QuePosition::VECOUT, 1> scaleOutQueue;
  GlobalTensor<T> xGm_;
  GlobalTensor<int32_t> expertIdxGm_;
  GlobalTensor<float> quantSmoothGm;
  GlobalTensor<float> dynamicQuantScaleGm;    // one fp32 scale per output row
  GlobalTensor<int8_t> expandedXGm_;
  GlobalTensor<int32_t> expandedRowIdxGm_;
  GlobalTensor<int32_t> expandedExpertIdxGm_;
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;
  int64_t expertTokensCountOrCumsumFlag = 0;
  int64_t expertTokensBeforeCapacityFlag = 0;
  int64_t dropPadMode = 0;
  // Kept as members: produced by SortCompute, consumed/freed by CopyOutXQuant*.
  LocalTensor<uint32_t> expandDstToSrcRowLocal;
  LocalTensor<int32_t> expandedExpertIdxLocal;
};
// Loads the full expertIdx tensor into the first half of the UB buffer and
// fills the second half with the identity row indices 0..totalLength-1
// (the sort payload).
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyIn() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
    DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                     0, 0, 0};
    DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
    DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams);
    // Payload half starts at sortNum_ (the sort-repeat-aligned capacity).
    ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
    sortDataCopyInQueue.EnQue(inLocal);
}
// Sorts tokens by expert id and derives both routing maps:
//   - expandedExpertIdxLocal : expert id per expanded (sorted) row
//   - expandDstToSrcRowLocal : dst -> src row map (first sort's payload)
//   - expandedRowIdx         : src -> dst row map (second sort inverts the permutation)
// The hardware Sort is descending on float keys, so keys are negated to get
// an ascending order by expert id.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::SortCompute() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
    LocalTensor<int32_t> expertIdxLocal = inLocal[0];
    // In-place reinterpret: Sort operates on fp32 keys.
    LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
    Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    // Negate so the descending Sort yields ascending expert ids.
    Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Fill the tail of the last sort repeat with MIN_FP32 so padding sinks to the end.
    int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    // First sort: key = -expertIdx, payload = original row index.
    LocalTensor<float> concatLocal;
    LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast<uint32_t>();
    LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    // Split the sorted (key, payload) stream back into two tensors.
    LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor<float>();
    expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor<uint32_t>();
    LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
    Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
         this->totalLength);
    pipe_barrier(PIPE_V);
    // Undo the key negation to recover the real expert ids.
    Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> expandedExpertIdxLocalInt32;
    expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
    Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
    pipe_barrier(PIPE_V);
    expandedExpertIdxCopyOutQueue_.EnQue<int32_t>(expandedExpertIdxLocalInt32);
    // Second sort: sorting the dst->src map by source row (again via negated
    // keys) produces the inverse permutation, i.e. the src->dst map.
    LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor<uint32_t>();
    LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
    Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
    pipe_barrier(PIPE_V);
    // Regenerate identity payload (the buffer was clobbered by the first sort).
    ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
    pipe_barrier(PIPE_V);
    if (duplicateNum > 0) {
        int duplicateIndex = this->totalLength - duplicateNum;
        uint64_t mask0 = UINT64_MAX;
        mask0 = mask0 << duplicateNum;
        mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
        uint64_t mask[2] = {mask0, 0};
        Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
        pipe_barrier(PIPE_V);
    }
    Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    // Keys are discarded into tempTensor; the payload is the src->dst map.
    Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
    pipe_barrier(PIPE_V);
    expandedRowIdxCopyOutQueue_.EnQue<uint32_t>(expandedRowIdx);
    sortDataCopyInQueue.FreeTensor(inLocal);
}
// Writes the full src->dst routing map to GM. The tensor is deliberately
// re-enqueued (not freed): CopyOutXQuant1H/EH dequeues it again to scatter
// quantized rows.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutIdx() {
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    DataCopyParams intriParams;
    intriParams.blockCount = 1;
    intriParams.blockLen = this->totalLength * sizeof(int32_t);
    DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams);
    expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx);
}
// Scans the sorted expert ids on the scalar unit and produces either a
// per-expert token count or a running cumsum (selected by
// expertTokensCountOrCumsumFlag), then writes it to GM.
// Runs only on the last active core (see Process()).
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::ComputeExpertTokenCountOrCumsum() {
    expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
    LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue_.AllocTensor<int32_t>();
    int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
    Duplicate(expertTokensCount, 0, expertNumAlign);
    // Vector Duplicate must land before the scalar GetValue/SetValue below.
    SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
    int32_t lastExpertId = expandedExpertIdxLocal.GetValue(0);
    int64_t tokenCount = 0;
    // NOTE(review): lastExpertCount is never used in this function.
    int64_t lastExpertCount = 0;
    for (int64_t i = 0; i < this->totalLength; i++) {
        int32_t curExpertId = expandedExpertIdxLocal.GetValue(i);
        tokenCount++;
        // Close out every expert id between the previous and current value;
        // experts with zero tokens get the running cumsum (or 0 for counts).
        while (lastExpertId < curExpertId) {
            expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
            if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
                // Count mode resets the accumulator; cumsum mode keeps it running.
                tokenCount = 1;
            }
            lastExpertId++;
        }
    }
#ifndef __CCE_KT_TEST__
    // Flush the final expert; in cumsum mode also fill trailing (empty) experts.
    // NOTE(review): in __CCE_KT_TEST__ builds this branch is compiled out and
    // expertTokensCount is never freed — verify this is intentional for CPU twin runs.
    expertTokensCount.SetValue(lastExpertId, tokenCount);
    if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
        lastExpertId++;
        while (lastExpertId < this->expertNum) {
            expertTokensCount.SetValue(lastExpertId, tokenCount);
            lastExpertId++;
        }
    }
    DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
                                 0};
    if (this->expertTokensCountOrCumsumFlag > 0) {
        DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
    }
    expertTokensCopyOutQueue_.FreeTensor(expertTokensCount);
#endif
}
// Counterpart of ComputeExpertTokenCountOrCumsum for cores that do not own the
// token statistics: dequeues the sorted expert ids so the queue stays balanced
// and the tensor can be freed later by CopyOutXQuant1H/EH.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutEmpty() {
    expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
}
// Dynamically quantizes one row: out = round(x * smooth / (max(|x*smooth|)/127))
// as int8, emitting the per-row scale max(|x*smooth|)/127 alongside.
// smoothLocal is only read when smoothType != 0.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Compute(LocalTensor<float>& smoothLocal) {
    LocalTensor<float> inLocal = xCopyInQueue_.DeQue<float>();
    LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
    LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();
    if constexpr (!IsSameType<T, float>::value) {
        // Half/bf16 input was staged at offset colsAlign; widen to fp32 in place.
        Cast(inLocal, inLocal.ReinterpretCast<T>()[colsAlign], RoundMode::CAST_NONE, this->cols_);
        pipe_barrier(PIPE_V);
    }
    if (smoothType != 0) {
        Mul(inLocal, inLocal, smoothLocal, this->cols_);
        pipe_barrier(PIPE_V);
    }
    Abs(tempLocal, inLocal, this->cols_);
    pipe_barrier(PIPE_V);
    ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols_);
    pipe_barrier(PIPE_V);
    // NOTE(review): scalar GetValue after a vector op guarded only by
    // pipe_barrier(PIPE_V) — confirm no HardEvent::V_S wait is required here.
    // Also: an all-zero row yields maxValue == 0 and a 0/0 Div below — verify
    // upstream guarantees or intended saturation behavior.
    float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;
    Duplicate<float>(dynamicQuantLocal, maxValue, 8);
    Duplicate<float>(tempLocal, maxValue, this->cols_);
    pipe_barrier(PIPE_V);
    Div(tempLocal, inLocal, tempLocal, this->cols_);
    pipe_barrier(PIPE_V);
    // fp32 -> half -> int8: the half step is required by the cast unit.
    Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols_);
    pipe_barrier(PIPE_V);
    Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols_);
    calcQueue.FreeTensor(tempLocal);
    inputXOutQueue.EnQue(outLocal);
    scaleOutQueue.EnQue(dynamicQuantLocal);
}
// Quantize-and-scatter path for smoothType 0 or 1 (no smooth vector, or one
// shared [cols] vector). Each source row of x is loaded once, quantized once,
// and written to every destination slot its k routing entries map to.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuant1H() {
    // dst->src map and sorted expert ids are not needed on this path.
    expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
    expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    // This core handles routing entries [curRowsStart, curRowsEnd]; each maps
    // to source row entry / k.
    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;
    int64_t startXRow = curRowsStart / this->k_;
    int64_t endXRow = curRowsEnd / this->k_;
    DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};
    LocalTensor<float> smoothLocal;
    if (smoothType == 1) {
        // Load the single shared smooth vector once for all rows.
        smoothLocal = smoothInQueue.AllocTensor<float>();
        DataCopyPad(smoothLocal, quantSmoothGm, smoothCopyParams, {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();
    }
    for (int64_t row = startXRow; row <= endXRow; row++) {
        LocalTensor<T> xLocal = xCopyInQueue_.AllocTensor<T>();
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(xLocal, xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        } else {
            // Non-fp32 input is staged at offset colsAlign; Compute() widens it in place.
            DataCopyPad(xLocal[colsAlign], xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        }
        xCopyInQueue_.EnQue<T>(xLocal);
        Compute(smoothLocal);
        LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
        LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
        // Scatter this quantized row to every destination produced by its k entries.
        while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) {
            int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
            curRowsStart++;
            // Skip dropped entries and (dropless) rows beyond the activation cap.
            if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) {
                continue;
            }
            DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams);
            DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
        }
        xCopyInQueue_.FreeTensor(xLocal);
        inputXOutQueue.FreeTensor(outLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }
    if (smoothType == 1) {
        smoothInQueue.FreeTensor(smoothLocal);
    }
    expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
}
// Quantize-and-scatter path for smoothType == 2 (one smooth vector per
// expert). Iterates destination rows in sorted order, loading the source row
// and the destination expert's smooth vector for each.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::CopyOutXQuantEH() {
    // The src->dst map was already written to GM by core 0; not needed here.
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
    expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
    // The dst->src payload was negated for the second sort; negate back and
    // convert to int32 to recover usable source indices.
    Muls(expandDstToSrcRowLocal.ReinterpretCast<float>(), expandDstToSrcRowLocal.ReinterpretCast<float>(), (float)-1,
         this->totalLength);
    pipe_barrier(PIPE_V);
    LocalTensor<int32_t> sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast<int32_t>();
    Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast<float>(), RoundMode::CAST_ROUND, this->totalLength);
    int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
    int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1;
    DataCopyExtParams dataXCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothCopyParams{1, static_cast<uint32_t>(this->cols_ * sizeof(float)), 0, 0, 0};
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols_ * sizeof(int8_t)), 0, 0, 0};
    for (int64_t row = curRowsStart; row <= curRowsEnd; row++) {
        // Dropless: rows are processed in sorted order, so everything past the
        // activation cap can be skipped at once.
        if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) {
            break;
        }
        int32_t srcIdx = sortedRowIdx.GetValue(row);
        int32_t expertIdx = expandedExpertIdxLocal.GetValue(row);
        LocalTensor<T> inLocal = xCopyInQueue_.AllocTensor<T>();
        LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
        if constexpr (IsSameType<T, float>::value) {
            DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        } else {
            // Non-fp32 input is staged at offset colsAlign; Compute() widens it in place.
            DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0});
        }
        // Smooth vector of the expert this destination row belongs to.
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0});
        xCopyInQueue_.EnQue<T>(inLocal);
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue<float>();
        Compute(smoothLocal);
        LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
        DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0});
        LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
        DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams);
        xCopyInQueue_.FreeTensor(inLocal);
        smoothInQueue.FreeTensor(smoothLocal);
        inputXOutQueue.FreeTensor(outLocal);
        scaleOutQueue.FreeTensor(quantScaleLocal);
    }
    expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
    expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal);
}
// Per-core initialization: caches tiling parameters, binds the GM tensors and
// sizes all UB queues/buffers for the full-load dynamic-quant pipeline.
// The sort buffers hold (key, payload) pairs, hence the factor of 2 (kvFactor).
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
                                                          GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
                                                          GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale,
                                                          GM_ADDR workspace,
                                                          const MoeInitRoutingQuantV2TilingData* tilingData,
                                                          TPipe* tPipe) {
    this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp);
    // Flatten (cube block idx, vector sub-block idx) into one linear block index.
    this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num();
    this->k_ = tilingData->k;
    this->n_ = tilingData->n;
    this->cols_ = tilingData->cols;
    this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
    this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
    this->activateRows_ = this->gatherOutTilingData_->activateRows;
    // The last active core may carry a shorter tail of rows.
    if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) {
        this->coreRows_ = this->gatherOutTilingData_->lastCoreRows;
    } else {
        this->coreRows_ = this->gatherOutTilingData_->perCoreRows;
    }
    this->expertNum = tilingData->expertNum;
    this->dropPadMode = tilingData->dropPadMode;
    this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
    this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
    // Round up to whole 32-element sort repeats.
    this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    this->totalLength = tilingData->n * tilingData->k;
    this->smoothType = tilingData->smoothType;
    this->colsAlign = Align(this->cols_, sizeof(T));
    this->pipe = tPipe;
    xGm_.SetGlobalBuffer((__gm__ T*)x);
    expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
    expandedXGm_.SetGlobalBuffer((__gm__ int8_t*)expandedX);
    expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
    if (this->expertTokensCountOrCumsumFlag > 0) {
        // dropless
        expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                    Align(this->expertNum, sizeof(int32_t)));
    }
    quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
    dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
    int64_t kvFactor = 2;  // sort buffers hold key + payload
    int64_t buffSize = this->sortNum_ * sizeof(int32_t);
    pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize);
    pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize);
    pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
    pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize);
    pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
    pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
    pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
    if constexpr (IsSameType<T, float>::value) {
        pipe->InitBuffer(xCopyInQueue_, 1, AlignBytes(this->cols_, sizeof(float)));
    } else {
        // Double-sized: raw T row at offset colsAlign plus fp32 widened copy at 0.
        pipe->InitBuffer(xCopyInQueue_, 1, 2 * AlignBytes(this->cols_, sizeof(T)));
    }
    pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float)));
    pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float)));
    pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t)));
    pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
}
// Pipeline driver: load + sort expert ids, publish routing maps and token
// statistics, then quantize and scatter the activation rows.
template <typename T>
__aicore__ inline void MoeV2FullLoadDynamicQuant<T>::Process() {
    // Cores beyond the tiling plan have nothing to do.
    if (this->blockIdx_ >= this->needCoreNum_) {
        return;
    }
    CopyIn();
    SortCompute();
    // Only core 0 publishes the src->dst routing map.
    if (this->blockIdx_ == 0) {
        CopyOutIdx();
    }
    // The last planned core owns the per-expert token count/cumsum output; all
    // other cores just drain their queue to keep the pipeline balanced.
    const bool ownsTokenStats =
        this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE;
    if (ownsTokenStats) {
        ComputeExpertTokenCountOrCumsum();
    } else {
        CopyOutEmpty();
    }
    // smoothType == 2 uses per-expert smooth vectors; otherwise one (or no)
    // shared vector suffices.
    if (smoothType == 2) {
        CopyOutXQuantEH();
    } else {
        CopyOutXQuant1H();
    }
}
} // namespace MoeInitRoutingQuantV2
#endif  // MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H

View File

@@ -0,0 +1,155 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_v2_fullload_quant.h
* \brief
*/
#ifndef MOE_V2_FULL_LOAD_QUANT_H
#define MOE_V2_FULL_LOAD_QUANT_H
#include "moe_v2_fullload_quant_base.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Full-load MoE routing kernel with static int8 quantization:
// out = round(x * scale + offset) using a single scalar scale/offset pair read
// from GM. Sorting and routing come from MoeV2FullLoadQuantBase; this class
// adds the quantize-and-scatter step. T is the activation dtype of x.
template <typename T>
class MoeV2FullLoadQuant : public MoeV2FullLoadQuantBase {
 public:
  __aicore__ inline MoeV2FullLoadQuant(){};
  __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX,
                              GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                              const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
 private:
  __aicore__ inline void Compute(int64_t xLocalLength);
  __aicore__ inline void CopyOutX();
 private:
  TQue<QuePosition::VECOUT, 1> floatQueue;        // fp32 scratch for bf16 widening
  TQue<QuePosition::VECOUT, 1> halfQueue;         // half scratch for the quant math
  TQue<QuePosition::VECOUT, 1> inputXCopyOutQueue;  // int8 output rows
  GlobalTensor<T> xGm;
  GlobalTensor<float> scaleGm;
  GlobalTensor<float> offsetGm;
  float scale;   // scalar quant scale, read once from GM in Init()
  float offset;  // scalar quant offset, read once from GM in Init()
};
// Statically quantizes xLocalLength staged rows to int8:
// out = round(x * scale + offset), with a dtype-specific cast chain
// (the half intermediate is required by the cast unit).
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::Compute(int64_t xLocalLength) {
    LocalTensor<T> inLocal = xCopyInQueue.DeQue<T>();
    LocalTensor<int8_t> outLocal = inputXCopyOutQueue.AllocTensor<int8_t>();
    LocalTensor<float> floatLocal = floatQueue.AllocTensor<float>();
    LocalTensor<half> halfLocal = halfQueue.AllocTensor<half>();
    // Rows are staged padded to the int8-aligned column count.
    uint32_t elements = Align(this->cols, sizeof(int8_t)) * xLocalLength;
    if constexpr (IsSameType<T, bfloat16_t>::value) {
        // bf16 -> fp32 -> half, quantize in half, then round-trip through int32
        // (with deq scale 1.0) to get round-to-nearest-even int8.
        Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        LocalTensor<int32_t> intLocal = floatLocal.ReinterpretCast<int32_t>();
        Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        SetDeqScale((half)1.000000e+00f);
        pipe_barrier(PIPE_V);
        Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else if constexpr (IsSameType<T, float>::value) {
        // fp32: narrow to half, quantize, cast straight to int8.
        Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements);
        pipe_barrier(PIPE_V);
        Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
    } else {
        // half: quantize in place.
        Muls(inLocal, inLocal, static_cast<T>(this->scale), elements);
        pipe_barrier(PIPE_V);
        Adds(inLocal, inLocal, static_cast<T>(this->offset), elements);
        pipe_barrier(PIPE_V);
        Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements);
    }
    inputXCopyOutQueue.EnQue(outLocal);
    xCopyInQueue.FreeTensor(inLocal);
    floatQueue.FreeTensor(floatLocal);
    halfQueue.FreeTensor(halfLocal);
}
// Loads all source rows referenced by this core's routing slice in one strided
// DMA, quantizes them in a single Compute() call, then scatters each row to
// every destination slot its routing entries map to.
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::CopyOutX() {
    LocalTensor<T> xLocal = xCopyInQueue.AllocTensor<T>();
    LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue.DeQue<int32_t>();
    // Rows are staged with an int8-aligned pitch so the quantized output lines up.
    int64_t inFactor = Align(this->cols, sizeof(int8_t));
    int64_t curRowsStart = this->blockIdx * this->perCoreRows;
    int64_t startXRow = curRowsStart / this->k;
    int64_t endXRow = (curRowsStart + this->coreRows - 1) / this->k;
    // Gap (in 32-byte blocks) between the natural row size and the staged pitch.
    uint32_t dstStride = (inFactor * sizeof(T) - AlignBytes(this->cols, sizeof(T))) / BLOCK_BYTES;
    DataCopyExtParams dataXCopyParams{static_cast<uint16_t>(endXRow - startXRow + 1),
                                      static_cast<uint32_t>(this->cols * sizeof(T)), 0, dstStride, 0};
    DataCopyPadExtParams<T> dataXCopyPadParams{false, 0, 0, 0};
    DataCopyPad(xLocal, xGm[startXRow * this->cols], dataXCopyParams, dataXCopyPadParams);
    xCopyInQueue.EnQue(xLocal);
    Compute(endXRow - startXRow + 1);
    LocalTensor<int8_t> outLocal = inputXCopyOutQueue.DeQue<int8_t>();
    // Local routing-entry counter; distinct from the member top-k `this->k`.
    int64_t k = 0;
    DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
    for (int64_t i = startXRow; i <= endXRow; i++) {
        // Walk every routing entry whose source row is i and scatter the row.
        for (; k < this->perCoreRows && curRowsStart / this->k == i; curRowsStart++, k++) {
            int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
            if (outIndex < this->activateRows) {
                DataCopyPad(expandedXGm[outIndex * this->cols], outLocal[(i - startXRow) * inFactor], intriParams);
            }
        }
    }
    expandedRowIdxCopyOutQueue.FreeTensor(expandedRowIdx);
    inputXCopyOutQueue.FreeTensor(outLocal);
}
// Per-core initialization: delegates the common routing setup to InitBase,
// reads the scalar quant scale/offset once from GM, and sizes the UB queues
// for the number of distinct source rows this core touches.
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset,
                                                   GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                                   GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                                   const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe) {
    this->InitBase(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe);
    xGm.SetGlobalBuffer((__gm__ T*)x);
    scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1);
    offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1);
    // Static quantization parameters: single scalar pair, read once.
    this->scale = scaleGm.GetValue(0);
    this->offset = offsetGm.GetValue(0);
    // Number of distinct x rows covered by this core's routing entries.
    int64_t curRowsStart = this->blockIdx * this->perCoreRows;
    int64_t rowLength = (curRowsStart + this->coreRows - 1) / this->k - curRowsStart / this->k + 1;
    // Rows are staged with an int8-aligned pitch (see CopyOutX).
    int64_t xAlignedCount = Align(this->cols, sizeof(int8_t));
    pipe->InitBuffer(xCopyInQueue, bufferNum, xAlignedCount * sizeof(T) * rowLength);
    pipe->InitBuffer(inputXCopyOutQueue, 1, xAlignedCount * sizeof(int8_t) * rowLength);
    pipe->InitBuffer(floatQueue, 1, xAlignedCount * sizeof(float) * rowLength);
    pipe->InitBuffer(halfQueue, 1, xAlignedCount * sizeof(half) * rowLength);
}
// Pipeline driver: base-class sorting/routing, then quantized scatter.
template <typename T>
__aicore__ inline void MoeV2FullLoadQuant<T>::Process() {
    // Idle cores outside the tiling plan exit immediately.
    if (this->blockIdx >= this->needCoreNum) {
        return;
    }
    this->ProcessBase();  // sort expert ids and emit routing indices
    CopyOutX();           // quantize rows and scatter them to the expanded layout
}
} // namespace MoeInitRoutingQuantV2
#endif  // MOE_V2_FULL_LOAD_QUANT_H

View File

@@ -0,0 +1,279 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_v2_fullload_quant_base.h
* \brief
*/
#ifndef MOE_V2_FULL_LOAD_QUANT_BASE_H
#define MOE_V2_FULL_LOAD_QUANT_BASE_H
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Shared base for the full-load static-quant routing kernels: owns the
// expert-id sort, the routing-index copy-out, and the per-expert token
// count/cumsum computation. Derived classes add the activation quantization.
class MoeV2FullLoadQuantBase {
 public:
  __aicore__ inline MoeV2FullLoadQuantBase(){};
 protected:
  __aicore__ inline void InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                  GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                  const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void ProcessBase();
  __aicore__ inline void CopyIn();
  __aicore__ inline void SortCompute();
  __aicore__ inline void CopyOutIdx();
  __aicore__ inline void CopyOutEmpty();
  __aicore__ inline void ComputeExpertTokenCountOrCumsum();
 protected:
  const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;
  TPipe* pipe;
  int64_t tileLength;       // aligned expertIdx element count
  int64_t bufferNum = 1;
  int64_t totalLength;      // n * k routing entries
  int64_t coreNum;
  int64_t sortNum;          // totalLength rounded up to a full sort repeat
  int64_t blockIdx;         // linear core index
  int64_t needCoreNum;      // cores actually used by the tiling plan
  int64_t coreRows;         // rows handled by this core
  int64_t perCoreRows;      // rows per non-tail core
  int64_t k;                // top-k experts per token
  int64_t n;                // number of tokens
  int64_t cols;             // hidden size of x
  int64_t activateRows;     // rows actually materialized
  int64_t expertNum;
  int64_t expertCapacity;
  TQue<QuePosition::VECIN, 1> sortDataCopyInQueue;
  TBuf<TPosition::VECCALC> tempBuffer;    // scratch for Concat/Sort
  TBuf<TPosition::VECCALC> sortedBuffer;  // sorted (key, payload) stream
  TQue<QuePosition::VECIN, 1> xCopyInQueue;
  TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue;
  TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue;
  TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue;
  TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue;
  GlobalTensor<int32_t> expertIdxGm;
  GlobalTensor<int8_t> expandedXGm;
  GlobalTensor<int32_t> expandedRowIdxGm;
  GlobalTensor<int32_t> expandedExpertIdxGm;
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;
  int64_t expertTokensCountOrCumsumFlag = 0;
  int64_t expertTokensBeforeCapacityFlag = 0;
  int64_t dropPadMode = 0;
  static constexpr int64_t DST_BLK_STRIDE = 1;
  static constexpr int64_t DST_REP_STRIDE = 8;
  static constexpr int64_t FOUR_BLOCK_BYTES = 128;
};
// Loads the full expertIdx tensor into the first half of the UB buffer and
// fills the second half with the identity row indices 0..totalLength-1
// (the sort payload).
__aicore__ inline void MoeV2FullLoadQuantBase::CopyIn() {
    LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
    DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                     0, 0, 0};
    DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
    DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams);
    // Payload half starts at sortNum (the sort-repeat-aligned capacity).
    ArithProgression<int32_t>(inLocal[this->sortNum], 0, 1, this->totalLength);
    sortDataCopyInQueue.EnQue(inLocal);
}
// Stage 2: group tokens by expert via two sort passes.
// Pass 1 sorts (expertIdx, rowIdx) pairs, producing the sorted expert ids
// (expandedExpertIdx) and a dst->src row mapping. Pass 2 sorts that mapping
// back, yielding the src->dst mapping (expandedRowIdx). The Sort<> primitive
// orders keys descending, so keys are negated (Muls by -1) before sorting and
// negated back afterwards; tail slots of the last partial repeat are filled
// with MIN_FP32 so the padding sorts to the end.
__aicore__ inline void MoeV2FullLoadQuantBase::SortCompute() {
LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
LocalTensor<int32_t> expertIdxLocal = inLocal[0];
LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
// Sort keys must be fp32: convert the int expert ids in place.
Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
pipe_barrier(PIPE_V);
// Negate so the descending sort yields ascending expert order.
Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
pipe_barrier(PIPE_V);
int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
if (duplicateNum > 0) {
// Mask off valid elements and fill the rest of the partial repeat with
// MIN_FP32 so padded slots lose every comparison.
int duplicateIndex = this->totalLength - duplicateNum;
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
pipe_barrier(PIPE_V);
}
LocalTensor<float> concatLocal;
LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum));
Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
// Row ids (written by CopyIn at offset sortNum) travel as the sort payload.
LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum].template ReinterpretCast<uint32_t>();
LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum));
Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue.AllocTensor<float>();
LocalTensor<uint32_t> expandDstToSrcRowLocal = expandDstToSrcRowQueue.AllocTensor<uint32_t>();
LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
// Split sorted (key, payload) pairs: keys -> expert ids, payload -> dst->src.
Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
this->totalLength);
pipe_barrier(PIPE_V);
// Undo the negation applied before the sort.
Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
pipe_barrier(PIPE_V);
LocalTensor<int32_t> expandedExpertIdxLocalInt32;
expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
pipe_barrier(PIPE_V);
expandedExpertIdxCopyOutQueue.EnQue<int32_t>(expandedExpertIdxLocalInt32);
LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue.AllocTensor<uint32_t>();
LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
// Pass 2: sort the dst->src mapping (again negated) with a fresh ascending
// payload, so the extracted payload becomes the inverse src->dst mapping.
Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
pipe_barrier(PIPE_V);
ArithProgression<int32_t>(inLocal[this->sortNum], 0, 1, this->totalLength);
pipe_barrier(PIPE_V);
if (duplicateNum > 0) {
// Same tail padding as pass 1.
int duplicateIndex = this->totalLength - duplicateNum;
uint64_t mask0 = UINT64_MAX;
mask0 = mask0 << duplicateNum;
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
uint64_t mask[2] = {mask0, 0};
Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
pipe_barrier(PIPE_V);
}
Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
// Keys are discarded into tempTensor; the payload is the wanted mapping.
Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
pipe_barrier(PIPE_V);
expandedRowIdxCopyOutQueue.EnQue<uint32_t>(expandedRowIdx);
sortDataCopyInQueue.FreeTensor(inLocal);
expandDstToSrcRowQueue.FreeTensor(expandDstToSrcRowLocal);
}
// Writes the src->dst row mapping (expandedRowIdx) produced by SortCompute
// back to GM. Only called on block 0 (see ProcessBase).
// NOTE(review): the tensor is re-enqueued rather than freed after the copy —
// presumably to keep the queue balanced for a later consumer; confirm no
// dequeue is actually expected afterwards.
__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutIdx() {
LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue.DeQue<int32_t>();
DataCopyParams intriParams;
intriParams.blockCount = 1;
intriParams.blockLen = this->totalLength * sizeof(int32_t);
DataCopyPad(expandedRowIdxGm, expandedRowIdx, intriParams);
expandedRowIdxCopyOutQueue.EnQue(expandedRowIdx);
}
// Scans the sorted expert ids on the scalar unit and emits, per expert,
// either the token count (EXERPT_TOKENS_COUNT) or the running prefix sum
// (EXERPT_TOKENS_CUMSUM), then copies the result to GM. Experts with no
// tokens get 0 in COUNT mode (from the initial Duplicate) and the previous
// cumulative total in CUMSUM mode.
__aicore__ inline void MoeV2FullLoadQuantBase::ComputeExpertTokenCountOrCumsum() {
LocalTensor<int32_t> expandedExpertIdx = expandedExpertIdxCopyOutQueue.DeQue<int32_t>();
LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue.AllocTensor<int32_t>();
int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
// Zero-fill so experts that receive no tokens report 0 in COUNT mode.
Duplicate(expertTokensCount, 0, expertNumAlign);
// Vector writes must land before the scalar loop reads/writes the tensor.
SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
int32_t lastExpertId = expandedExpertIdx.GetValue(0);
int64_t tokenCount = 0;
int64_t lastExpertCount = 0;
for (int64_t i = 0; i < this->totalLength; i++) {
int32_t curExpertId = expandedExpertIdx.GetValue(i);
tokenCount++;
// Expert id increased: close out every expert up to curExpertId. Skipped
// ids (empty experts) also get the value written here.
while (lastExpertId < curExpertId) {
expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
// COUNT mode restarts the counter per expert; CUMSUM keeps it running.
tokenCount = 1;
}
lastExpertId++;
}
}
expertTokensCount.SetValue(lastExpertId, tokenCount);
if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
// Trailing experts with no tokens carry the final cumulative total.
lastExpertId++;
while (lastExpertId < this->expertNum) {
expertTokensCount.SetValue(lastExpertId, tokenCount);
lastExpertId++;
}
}
DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
0};
if (this->expertTokensCountOrCumsumFlag > 0) {
DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
}
expertTokensCopyOutQueue.FreeTensor(expertTokensCount);
expandedExpertIdxCopyOutQueue.FreeTensor(expandedExpertIdx);
}
// Drains the expanded-expert-index queue without producing any GM output.
// Cores that do not own the expert-token statistics still must dequeue and
// release the tensor enqueued by SortCompute to keep the queue balanced.
__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutEmpty() {
LocalTensor<int32_t> drained = expandedExpertIdxCopyOutQueue.DeQue<int32_t>();
expandedExpertIdxCopyOutQueue.FreeTensor(drained);
}
// Reads tiling parameters, binds the GM buffers, and sizes all UB queues for
// the full-load (everything fits in UB at once) quantized routing path.
// sortNum is totalLength rounded up to a whole number of sort repeats; each
// sort-related buffer therefore holds sortNum elements (x2 for key+payload).
__aicore__ inline void MoeV2FullLoadQuantBase::InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX,
GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum,
GM_ADDR workspace,
const MoeInitRoutingQuantV2TilingData* tilingData,
TPipe* tPipe) {
this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);
// Flatten (block, subblock) into one linear core index.
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
this->k = tilingData->k;
this->n = tilingData->n;
this->cols = tilingData->cols;
this->needCoreNum = this->gatherOutTilingData->needCoreNum;
this->perCoreRows = this->gatherOutTilingData->perCoreRows;
this->activateRows = this->gatherOutTilingData->activateRows;
// The last active core takes the (possibly smaller) remainder of rows.
if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
this->coreRows = this->gatherOutTilingData->lastCoreRows;
} else {
this->coreRows = this->gatherOutTilingData->perCoreRows;
}
this->expertNum = tilingData->expertNum;
this->dropPadMode = tilingData->dropPadMode;
this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
// Round up to a whole number of sort repeats (ONE_REPEAT_SORT_NUM each).
this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
this->totalLength = tilingData->n * tilingData->k;
this->pipe = tPipe;
expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);
expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
if (this->expertTokensCountOrCumsumFlag > 0) {
// dropless
expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
Align(this->expertNum, sizeof(int32_t)));
}
// kvFactor = 2: sort buffers hold key + payload pairs.
int64_t kvFactor = 2;
int64_t buffSize = this->sortNum * sizeof(int32_t);
pipe->InitBuffer(expandedRowIdxCopyOutQueue, bufferNum, buffSize);
pipe->InitBuffer(expandedExpertIdxCopyOutQueue, bufferNum, buffSize);
pipe->InitBuffer(expertTokensCopyOutQueue, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
pipe->InitBuffer(expandDstToSrcRowQueue, bufferNum, buffSize);
pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
}
// Full-load routing pipeline: load indices, sort them, and emit per-core
// outputs. Core 0 writes the expanded row index; the last active core emits
// the expert-token statistics when requested; every other core just drains
// its queue.
__aicore__ inline void MoeV2FullLoadQuantBase::ProcessBase() {
// Cores beyond the tiling-computed count have nothing to do.
if (this->blockIdx >= this->needCoreNum) {
return;
}
CopyIn();
SortCompute();
// Only the first core owns the expandedRowIdx output.
if (this->blockIdx == 0) {
CopyOutIdx();
}
const bool isLastCore = (this->blockIdx == this->needCoreNum - 1);
if (isLastCore && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
ComputeExpertTokenCountOrCumsum();
} else {
CopyOutEmpty();
}
}
} // namespace MoeInitRoutingQuantV2
#endif // MOE_V2_FULL_LOAD_QUANT_BASE_H

// ==== New file: moe_v2_gather_dynamic_quant.h (568 lines added) ====
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_gather_dynamic_quant.h
* \brief
*/
#ifndef MOE_V2_GATHER_DYNAMIC_QUANT_H
#define MOE_V2_GATHER_DYNAMIC_QUANT_H
#include "moe_v2_common.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Gathers routed token rows into expandedX with per-token dynamic int8
// quantization (scale = max|x| / 127 per output row), optionally applying a
// smooth vector first. T is the input element type (float or a 16-bit type
// cast to float on chip). Column-tiled paths spill fp32 rows to workspace
// when a full row does not fit in UB.
template <typename T>
class MoeV2GatherDynamicQuant {
public:
__aicore__ inline MoeV2GatherDynamicQuant(){};
__aicore__ inline void Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
GM_ADDR dynamicQuantScale, GM_ADDR workspace,
const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
__aicore__ inline void Process();
private:
// Loads one loop's worth of expandedRowIdx (src->dst mapping).
__aicore__ inline void CopyInExpandedRowIdx(int64_t progress);
// Loads one loop's worth of sorted row ids plus their expert ids.
__aicore__ inline void CopyInExpandedExpertIdx(int64_t progress);
// Full-row-in-UB quantize paths: 1H = shared/no smooth, EH = per-expert smooth.
__aicore__ inline void CopyOutXQuant1H(int64_t progress);
__aicore__ inline void CopyOutXQuantEH(int64_t progress);
// Quantizes one full row already resident in UB.
__aicore__ inline void Compute(LocalTensor<float>& smoothLocal);
// Column-tiled (two-pass via workspace) quantize paths.
__aicore__ inline void CopyOutPartialXQuantEH(int64_t progress);
__aicore__ inline void CopyOutPartialXQuant1H(int64_t progress);
// Pass 1 of the tiled path: load a tile, return its abs-max.
__aicore__ inline float ComputeMax(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal,
LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx,
int64_t j);
// Pass 2 of the tiled path: rescale a tile and write int8 output.
__aicore__ inline void ComputeScale(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal, float scaleTemp,
int64_t dstIndex, int64_t j);
private:
TPipe* pipe;
// UB queues for input rows, smooth vectors, index tiles, and outputs.
TQue<QuePosition::VECIN, BUFFER_NUM> inputXInQueue;
TQue<QuePosition::VECIN, BUFFER_NUM> smoothInQueue;
TQue<QuePosition::VECIN, BUFFER_NUM> expandRowIdxInQueue;
TQue<QuePosition::VECOUT, 1> calcQueue;
TQue<QuePosition::VECOUT, 1> inputXOutQueue;
TQue<QuePosition::VECOUT, 1> scaleOutQueue;
GlobalTensor<T> inputXGm;
GlobalTensor<int8_t> expandedXGm;
GlobalTensor<int32_t> expandedRowIdxGm;
GlobalTensor<float> quantSmoothGm;
GlobalTensor<float> dynamicQuantScaleGm;
// Workspace views: fp32 row scratch, sorted expert ids, sorted row ids.
GlobalTensor<float> quantSrcGm;
GlobalTensor<int32_t> expandedExpertIdxGm;
GlobalTensor<int32_t> sortedRowIdxGm;
const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;
// Tiling-derived loop bounds; see Init for how they are populated.
int64_t needCoreNum;
int64_t blockIdx;
int64_t cols;
int64_t n;
int64_t k;
int64_t totalLength;
int64_t activateRows;
int64_t currentLoopRows;
int64_t currentLoopRowsAlign;
int64_t coreRows;
int64_t perLoopRows;
int64_t lastLoopRows;
int64_t rowLoops;
int64_t colsTileLength;
int64_t perLoopCols;
int64_t perLoopColsAlign;
int64_t lastLoopCols;
int64_t colLoops;
int64_t dropPadMode;
// smoothType: 0 = none, 1 = single shared vector, 2 = per-expert vectors.
int64_t smoothType;
int64_t indicesOffset;
int64_t inputOffset;
int64_t outOffset;
};
// Loads this loop's slice of the src->dst row mapping (expandedRowIdx) into
// UB; `progress` is the row-loop index for this core.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyInExpandedRowIdx(int64_t progress) {
this->indicesOffset = progress * this->perLoopRows;
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.AllocTensor<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
expandRowIdxInQueue.EnQue<int32_t>(indicesLocal);
}
// Loads this loop's slice of both sorted row ids and their expert ids into a
// single UB tensor: row ids at offset 0, expert ids at currentLoopRowsAlign.
// Used by the per-expert smooth (EH) paths.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyInExpandedExpertIdx(int64_t progress) {
this->indicesOffset = progress * this->perLoopRows;
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.AllocTensor<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
DataCopyPad(indicesLocal, sortedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
DataCopyPad(indicesLocal[currentLoopRowsAlign], expandedExpertIdxGm[indicesOffset], dataCopyParams,
dataCopyPadParams);
expandRowIdxInQueue.EnQue<int32_t>(indicesLocal);
}
// Quantizes one full row resident in UB: optional smooth multiply, then
// per-row dynamic scale = max|x| / 127, divide, and cast fp32 -> half -> int8.
// Enqueues the int8 row and an 8-element scale tensor for copy-out.
// smoothLocal is only read when smoothType != 0.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Compute(LocalTensor<float>& smoothLocal) {
LocalTensor<float> inLocal = inputXInQueue.DeQue<float>();
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();
if constexpr (!IsSameType<T, float>::value) {
// Non-fp32 input was staged at offset perLoopColsAlign; widen to fp32.
Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
pipe_barrier(PIPE_V);
}
if (smoothType != 0) {
Mul(inLocal, inLocal, smoothLocal, this->cols);
pipe_barrier(PIPE_V);
}
Abs(tempLocal, inLocal, this->cols);
pipe_barrier(PIPE_V);
ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols);
pipe_barrier(PIPE_V);
// NOTE(review): scalar GetValue right after a vector ReduceMax — looks like
// it relies on pipe_barrier(PIPE_V) for ordering; confirm no V->S event is
// required here.
float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;
Duplicate<float>(dynamicQuantLocal, maxValue, 8);
Duplicate<float>(tempLocal, maxValue, this->cols);
pipe_barrier(PIPE_V);
Div(tempLocal, inLocal, tempLocal, this->cols);
pipe_barrier(PIPE_V);
// Two-step narrowing: fp32 -> half (truncate) -> int8 (round).
Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols);
pipe_barrier(PIPE_V);
Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols);
calcQueue.FreeTensor(tempLocal);
inputXOutQueue.EnQue(outLocal);
scaleOutQueue.EnQue(dynamicQuantLocal);
}
// Full-row-in-UB quantize path with a single shared smooth vector (or none).
// Each source row is loaded and quantized once, then fanned out to the k
// destination rows given by expandedRowIdx; rows mapped to -1 or beyond
// activateRows (dropless mode) are skipped.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progress) {
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
// First flat (row*k) index owned by this core for this loop.
int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
int64_t curLoopRow = 0;
int64_t currentLoopStartRow = initialRow / this->k;
int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
LocalTensor<float> smoothLocal;
if (smoothType == 1) {
// Shared smooth vector: load once per loop.
smoothLocal = smoothInQueue.AllocTensor<float>();
DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0});
smoothInQueue.EnQue(smoothLocal);
smoothLocal = smoothInQueue.DeQue<float>();
}
for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
LocalTensor<T> inLocal = inputXInQueue.AllocTensor<T>();
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(inLocal, inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0});
} else {
// Non-fp32 data is staged past perLoopColsAlign so the fp32 view fits in front.
DataCopyPad(inLocal[perLoopColsAlign], inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0});
}
inputXInQueue.EnQue<T>(inLocal);
// Compute quantization for this row.
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
// Fan the quantized row out to every destination slot of this source row.
while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
int32_t outIndex = indicesLocal.GetValue(curLoopRow);
curLoopRow++;
initialRow++;
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
continue;
}
DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams);
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
}
inputXInQueue.FreeTensor(inLocal);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
if (smoothType == 1) {
smoothInQueue.FreeTensor(smoothLocal);
}
expandRowIdxInQueue.FreeTensor(indicesLocal);
}
// Full-row-in-UB quantize path with per-expert smooth vectors (smoothType 2).
// Iterates destination rows in sorted order; reloads the smooth vector only
// when the expert id changes between consecutive rows.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuantEH(int64_t progress) {
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
// Indices must have landed before scalar GetValue reads below.
SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(T)), 0, 0, 0};
DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(float)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->perLoopCols * sizeof(int8_t)), 0, 0, 0};
int32_t lastExpertIdx = -1;
LocalTensor<T> inLocal = inputXInQueue.AllocTensor<T>();
LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
for (int64_t i = 0; i < this->currentLoopRows; i++) {
int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
// Dropless mode only materializes the first activateRows outputs.
if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) {
break;
}
int32_t srcIdx = indicesLocal.GetValue(i);
int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i);
if constexpr (IsSameType<T, float>::value) {
DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
} else {
DataCopyPad(inLocal[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
}
inputXInQueue.EnQue<T>(inLocal);
// Rows are expert-sorted, so the smooth vector changes rarely.
if (expertIdx != lastExpertIdx) {
DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0});
smoothInQueue.EnQue(smoothLocal);
smoothLocal = smoothInQueue.DeQue<float>();
lastExpertIdx = expertIdx;
}
Compute(smoothLocal);
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0});
LocalTensor<int8_t> outLocal = inputXOutQueue.DeQue<int8_t>();
DataCopyPad(expandedXGm[(rowOffset + i) * this->cols], outLocal, copyOutParams);
inputXOutQueue.FreeTensor(outLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
inputXInQueue.FreeTensor(inLocal);
smoothInQueue.FreeTensor(smoothLocal);
expandRowIdxInQueue.FreeTensor(indicesLocal);
}
// Column-tiled pass 1: loads tile j of source row srcIdx, applies the optional
// smooth tile for expertIdx, computes the tile's abs-max, and spills the fp32
// tile to per-core workspace (quantSrcGm) for pass 2. Returns the tile max so
// the caller can reduce across tiles.
template <typename T>
__aicore__ inline float MoeV2GatherDynamicQuant<T>::ComputeMax(LocalTensor<float>& inLocal,
LocalTensor<float>& tempLocal,
LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx,
int32_t expertIdx, int64_t j) {
LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
DataCopyExtParams intriParamsT{1, static_cast<uint32_t>(colsTileLength * sizeof(T)), 0, 0, 0};
DataCopyExtParams intriParamsFp32{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
if constexpr (!IsSameType<T, float>::value) {
// Stage non-fp32 data behind the fp32 region; widened in place below.
DataCopyPad(inLocal.ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols],
intriParamsT, {false, 0, 0, 0});
} else {
DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
}
inputXInQueue.EnQue<float>(inLocal);
inLocal = inputXInQueue.DeQue<float>();
if (smoothType != 0) {
DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32,
{false, 0, 0, 0});
smoothInQueue.EnQue(smoothLocal);
smoothLocal = smoothInQueue.DeQue<float>();
}
if constexpr (!IsSameType<T, float>::value) {
Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength);
pipe_barrier(PIPE_V);
}
if (smoothType != 0) {
Mul(inLocal, inLocal, smoothLocal, colsTileLength);
pipe_barrier(PIPE_V);
}
Abs(tempLocal, inLocal, colsTileLength);
pipe_barrier(PIPE_V);
// Tile max lands at element 8 so it does not clobber the final row scale
// kept in elements 0..7 by the caller.
ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength);
// Spill the smoothed fp32 tile so pass 2 need not recompute it.
DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32);
smoothInQueue.FreeTensor(smoothLocal);
SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
return dynamicQuantLocal.GetValue(8);
}
// Column-tiled pass 2: reloads fp32 tile j from workspace, divides by the
// row-wide scale computed in pass 1, narrows fp32 -> half -> int8, and writes
// the tile into output row dstIndex of expandedX.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::ComputeScale(LocalTensor<float>& inLocal,
LocalTensor<float>& tempLocal, float scaleTemp,
int64_t dstIndex, int64_t j) {
DataCopyExtParams copyInParams{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(colsTileLength * sizeof(int8_t)), 0, 0, 0};
LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0});
inputXInQueue.EnQue<float>(inLocal);
inLocal = inputXInQueue.DeQue<float>();
Duplicate<float>(tempLocal, scaleTemp, colsTileLength);
pipe_barrier(PIPE_V);
Div(tempLocal, inLocal, tempLocal, colsTileLength);
pipe_barrier(PIPE_V);
// Two-step narrowing: fp32 -> half (truncate) -> int8 (round).
Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength);
pipe_barrier(PIPE_V);
Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, colsTileLength);
inputXOutQueue.EnQue(outLocal);
outLocal = inputXOutQueue.DeQue<int8_t>();
DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams);
inputXOutQueue.FreeTensor(outLocal);
SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
}
// Column-tiled quantize path with per-expert smooth vectors: for each output
// row, pass 1 (ComputeMax) scans every column tile to find the row abs-max,
// then pass 2 (ComputeScale) rescales each tile and writes the int8 output.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutPartialXQuantEH(int64_t progress) {
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
// Indices must have landed before scalar GetValue reads below.
SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
for (int64_t i = 0; i < this->currentLoopRows; i++) {
int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) {
break;
}
int32_t srcIdx = indicesLocal.GetValue(i);
int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i);
LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();
// 0xFF7FFFFF is the bit pattern of -FLT_MAX: the reduction identity.
uint32_t tmp = 0xFF7FFFFF;
float reduceMax = *((float*)&tmp);
for (int64_t j = 0; j < this->colLoops; j++) {
colsTileLength = this->perLoopCols;
if (j == this->colLoops - 1) {
colsTileLength = this->lastLoopCols;
}
float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j);
reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
}
float scaleTemp = reduceMax / 127.0f;
Duplicate<float>(quantScaleLocal, scaleTemp, 8);
scaleOutQueue.EnQue(quantScaleLocal);
quantScaleLocal = scaleOutQueue.DeQue<float>();
DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0});
for (int64_t j = 0; j < this->colLoops; j++) {
colsTileLength = this->perLoopCols;
if (j == this->colLoops - 1) {
colsTileLength = this->lastLoopCols;
}
ComputeScale(inLocal, tempLocal, scaleTemp, rowOffset + i, j);
}
inputXInQueue.FreeTensor(inLocal);
calcQueue.FreeTensor(tempLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
expandRowIdxInQueue.FreeTensor(indicesLocal);
}
// Column-tiled quantize path with a shared smooth vector (or none): pass 1
// computes each source row's abs-max across column tiles, pass 2 rescales and
// fans the int8 tiles out to all k destination rows of that source row.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutPartialXQuant1H(int64_t progress) {
LocalTensor<int32_t> indicesLocal = expandRowIdxInQueue.DeQue<int32_t>();
int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
int64_t curLoopRow = 0;
int64_t currentLoopStartRow = initialRow / this->k;
int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();
// 0xFF7FFFFF is the bit pattern of -FLT_MAX: the reduction identity.
uint32_t tmp = 0xFF7FFFFF;
float reduceMax = *((float*)&tmp);
for (int64_t j = 0; j < this->colLoops; j++) {
colsTileLength = this->perLoopCols;
if (j == this->colLoops - 1) {
colsTileLength = this->lastLoopCols;
}
// Expert id 0: 1H mode uses at most one shared smooth vector.
float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, row, 0, j);
reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
}
float scaleTemp = reduceMax / 127.0f;
Duplicate<float>(quantScaleLocal, scaleTemp, 8);
scaleOutQueue.EnQue(quantScaleLocal);
quantScaleLocal = scaleOutQueue.DeQue<float>();
// Fan out to every destination slot of this source row.
while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
int32_t outIndex = indicesLocal.GetValue(curLoopRow);
curLoopRow++;
initialRow++;
if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
continue;
}
DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0});
for (int64_t j = 0; j < this->colLoops; j++) {
colsTileLength = this->perLoopCols;
if (j == this->colLoops - 1) {
colsTileLength = this->lastLoopCols;
}
ComputeScale(inLocal, tempLocal, scaleTemp, outIndex, j);
}
}
inputXInQueue.FreeTensor(inLocal);
calcQueue.FreeTensor(tempLocal);
scaleOutQueue.FreeTensor(quantScaleLocal);
}
expandRowIdxInQueue.FreeTensor(indicesLocal);
}
// Reads tiling parameters, binds GM buffers, and sizes UB queues for the
// dynamic-quant gather. Workspace layout (int32-aligned): sorted expert ids
// [0, totalLength), sorted row ids [totalLength, 2*totalLength), then a
// per-core fp32 row scratch area used by the column-tiled two-pass path.
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx,
GM_ADDR expandedX, GM_ADDR dynamicQuantScale, GM_ADDR workspace,
const MoeInitRoutingQuantV2TilingData* tilingData,
TPipe* tPipe) {
this->pipe = tPipe;
// Flatten (block, subblock) into one linear core index.
this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);
this->needCoreNum = this->gatherOutTilingData->needCoreNum;
this->activateRows = this->gatherOutTilingData->activateRows;
this->cols = tilingData->cols;
this->n = tilingData->n;
this->k = tilingData->k;
this->totalLength = tilingData->n * tilingData->k;
this->dropPadMode = tilingData->dropPadMode;
this->smoothType = tilingData->smoothType;
// The last active core takes the (possibly smaller) remainder of rows.
if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
this->coreRows = this->gatherOutTilingData->lastCoreRows;
this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
} else {
this->coreRows = this->gatherOutTilingData->perCoreRows;
this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
this->rowLoops = this->gatherOutTilingData->perCoreLoops;
}
this->perLoopCols = this->gatherOutTilingData->perLoopCols;
this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
this->colLoops = this->gatherOutTilingData->colLoops;
this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T));
inputXGm.SetGlobalBuffer((__gm__ T*)inputX);
expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);
// Each GM index view is offset to this core's slice of rows.
expandedRowIdxGm.SetGlobalBuffer(
(__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
Align(this->coreRows, sizeof(int32_t)));
quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth);
dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
expandedExpertIdxGm.SetGlobalBuffer(
(__gm__ int32_t*)workspace + this->blockIdx * this->gatherOutTilingData->perCoreRows,
Align(this->coreRows, sizeof(int32_t)));
sortedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) +
this->blockIdx * this->gatherOutTilingData->perCoreRows,
Align(this->coreRows, sizeof(int32_t)));
if (this->cols > 1) {
// Per-core fp32 scratch row after both index arrays; used by ComputeMax /
// ComputeScale when a row must be processed in column tiles.
quantSrcGm.SetGlobalBuffer(
(__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2 + this->blockIdx * this->cols,
this->cols * sizeof(float));
}
this->currentLoopRowsAlign = Align(this->perLoopRows, sizeof(int32_t));
// Input buffer must hold the fp32 widening of a T tile (hence * sizeof(float)
// / sizeof(T)), with a small floor of two blocks.
int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T));
perLoopColsAlignBytes =
Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES));
// x2: row-id tile plus expert-id tile share one queue buffer.
pipe->InitBuffer(expandRowIdxInQueue, BUFFER_NUM, 2 * AlignBytes(this->perLoopRows, sizeof(int32_t)));
pipe->InitBuffer(inputXInQueue, BUFFER_NUM, perLoopColsAlignBytes);
pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
}
// Top-level dispatch: choose the per-expert-smooth (EH) vs shared-smooth (1H)
// index path, and the full-row-in-UB vs column-tiled (workspace) quantize
// path, then run every row loop; the final loop carries the remainder rows.
// Assumes rowLoops >= 1 (guaranteed by tiling).
template <typename T>
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
// Cores beyond the tiling-computed count have nothing to do.
if (this->blockIdx >= this->needCoreNum) {
return;
}
const bool perExpertSmooth = (smoothType == 2);
// colLoops > 1 means one row cannot be fully loaded into UB and the
// two-pass workspace path is required.
const bool rowFitsInUb = (colLoops <= 1);
for (int64_t loop = 0; loop < this->rowLoops; loop++) {
currentLoopRows = (loop == this->rowLoops - 1) ? lastLoopRows : perLoopRows;
if (perExpertSmooth) {
CopyInExpandedExpertIdx(loop);
if (rowFitsInUb) {
CopyOutXQuantEH(loop);
} else {
CopyOutPartialXQuantEH(loop);
}
} else {
CopyInExpandedRowIdx(loop);
if (rowFitsInUb) {
CopyOutXQuant1H(loop);
} else {
CopyOutPartialXQuant1H(loop);
}
}
}
}
} // namespace MoeInitRoutingQuantV2
#endif // MOE_V2_GATHER_DYNAMIC_QUANT_H

// ==== New file: moe_v2_gather_out.h (181 lines added) ====
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_gather_out.h
* \brief
*/
#ifndef INNER_MOE_V2_GATHER_OUT_H
#define INNER_MOE_V2_GATHER_OUT_H
#include "moe_v2_common.h"
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Double-buffering depth for the copy-in queues in this namespace.
constexpr int64_t BUFFER_NUM = 2;
// Non-quantized gather: copies input rows of element type T into expandedX
// according to the src->dst mapping in expandedRowIdx, tiling columns when a
// full row does not fit in UB.
template <typename T>
class MoeV2GatherOut {
public:
__aicore__ inline MoeV2GatherOut(){};
__aicore__ inline void Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace,
const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe);
__aicore__ inline void Process();
private:
// Loads one loop's worth of expandedRowIdx into UB.
__aicore__ inline void CopyInIndices(int64_t progress);
// Copies the corresponding input rows out to their destination slots.
__aicore__ inline void CopyOut(int64_t progress);
private:
TPipe* pipe;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, BUFFER_NUM> inputActivationsCopyInQueue;
TQue<QuePosition::VECIN, BUFFER_NUM> expandDstToSrcRowCopyInQueue;
GlobalTensor<T> inputXGm;
GlobalTensor<T> expandedXGm;
GlobalTensor<int32_t> expandedRowIdxGm;
const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;
// Tiling-derived loop bounds; populated by Init.
int64_t needCoreNum;
int64_t blockIdx;
int64_t cols;
int64_t n;
int64_t k;
int64_t activateRows;
int64_t currentLoopRows;
int64_t coreRows;
int64_t perLoopRows;
int64_t lastLoopRows;
int64_t rowLoops;
int64_t colsTileLength;
int64_t perLoopCols;
int64_t lastLoopCols;
int64_t colLoops;
int64_t dropPadMode;
int64_t indicesOffset;
int64_t inputOffset;
int64_t outOffset;
};
// Loads this loop's slice of the src->dst row mapping (expandedRowIdx) into
// UB; `progress` is the row-loop index for this core.
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::CopyInIndices(int64_t progress) {
this->indicesOffset = progress * this->perLoopRows;
LocalTensor<int32_t> indicesLocal = expandDstToSrcRowCopyInQueue.AllocTensor<int32_t>();
DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0};
DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams);
expandDstToSrcRowCopyInQueue.EnQue<int32_t>(indicesLocal);
}
// Scatter stage: for each column tile, load each needed source row once and
// write it to every destination slot the staged index batch maps it to.
// Indices of -1 (and, in dropless mode, indices >= activateRows) are skipped.
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::CopyOut(int64_t progress) {
  LocalTensor<int32_t> indicesLocal = expandDstToSrcRowCopyInQueue.DeQue<int32_t>();
  // Ensure the index DMA has landed before scalar GetValue reads below.
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  colsTileLength = this->perLoopCols;
  for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) {
    // Global position of the first index entry handled by this loop.
    int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
    int64_t curLoopRow = 0;
    if (colsLoop == this->colLoops - 1) {
      colsTileLength = this->lastLoopCols;
    }
    // Each source row covers k consecutive entries of the index array.
    int64_t currentLoopStartRow = initialRow / this->k;
    int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
    for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
      LocalTensor<T> inLocal = inputActivationsCopyInQueue.AllocTensor<T>();
      // input row position
      inputOffset = row * this->cols + colsLoop * this->perLoopCols;
      DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
      DataCopyPadExtParams<T> dataCopyPadParams{false, 0, 0, 0};
      DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams);
      // The staged tile is both the MTE2 destination and the MTE3 source;
      // order the two DMA engines explicitly.
      SetWaitFlag<HardEvent::MTE2_MTE3>(HardEvent::MTE2_MTE3);
      DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
      while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
        int32_t outIndex = indicesLocal.GetValue(curLoopRow);
        curLoopRow++;
        initialRow++;
        if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
          continue;
        }
        outOffset = outIndex * cols + colsLoop * this->perLoopCols;
#ifdef __CCE_KT_TEST__
        // CPU twin debugging cannot use multi-core synchronization, so the
        // index may be uninitialized dirty data; guard against it here.
        if (outOffset > expandedXGm.GetSize()) {
          continue;
        }
#endif
        DataCopyPad(expandedXGm[outOffset], inLocal, intriParams);
      }
      inputActivationsCopyInQueue.FreeTensor(inLocal);
    }
  }
  expandDstToSrcRowCopyInQueue.FreeTensor(indicesLocal);
}
// Binds GM tensors, caches per-core tiling parameters, and allocates the UB
// queues. The last working core uses the remainder ("lastCore*") tiling.
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                               GM_ADDR workspace, const InnerMoeInitRoutingV2TilingData* tilingData,
                                               TPipe* tPipe) {
  this->pipe = tPipe;
  // Logical core index combining block index and sub-block id.
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);
  this->needCoreNum = this->gatherOutTilingData->needCoreNum;
  this->activateRows = this->gatherOutTilingData->activateRows;
  this->cols = tilingData->cols;
  this->n = tilingData->n;
  this->k = tilingData->k;
  this->dropPadMode = tilingData->dropPadMode;
  if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
    // Tail core: takes the remainder rows.
    this->coreRows = this->gatherOutTilingData->lastCoreRows;
    this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
    this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
    this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
  } else {
    this->coreRows = this->gatherOutTilingData->perCoreRows;
    this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
    this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
    this->rowLoops = this->gatherOutTilingData->perCoreLoops;
  }
  this->perLoopCols = this->gatherOutTilingData->perLoopCols;
  this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
  this->colLoops = this->gatherOutTilingData->colLoops;
  inputXGm.SetGlobalBuffer((__gm__ T*)inputX, this->coreRows * this->cols);
  // Expanded output covers n * k rows.
  expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, tilingData->n * tilingData->k * this->cols);
  // Index buffer is offset to this core's slice of the row indices.
  expandedRowIdxGm.SetGlobalBuffer(
      (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
      Align(this->coreRows, sizeof(int32_t)));
  pipe->InitBuffer(inputActivationsCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T)));
  pipe->InitBuffer(expandDstToSrcRowCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t)));
}
// Row-loop driver: iterates over this core's row loops, staging destination
// indices and scattering the corresponding rows. The final loop may process a
// shorter (remainder) batch.
template <typename T>
__aicore__ inline void MoeV2GatherOut<T>::Process() {
  // Cores beyond the tiling's working set stay idle.
  if (this->blockIdx >= this->needCoreNum) {
    return;
  }
  currentLoopRows = perLoopRows;
  for (int64_t loop = 0; loop < this->rowLoops; loop++) {
    if (loop == this->rowLoops - 1) {
      currentLoopRows = lastLoopRows;
    }
    CopyInIndices(loop);
    CopyOut(loop);
  }
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_GATHER_OUT_H

View File

@@ -0,0 +1,235 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_gather_quant.h
* \brief
*/
#ifndef MOE_V2_GATHER_QUANT_H
#define MOE_V2_GATHER_QUANT_H
#include "moe_v2_common.h"
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
constexpr int64_t BUFFER_NUM = 2;
// Quantized gather stage: same row scatter as MoeV2GatherOut, but each staged
// tile is quantized to int8 (out = round(in * scale + offset)) before being
// written to the expanded output. scale/offset are single floats read from GM.
template <typename T>
class MoeV2GatherQuant {
 public:
  __aicore__ inline MoeV2GatherQuant(){};
  // Binds GM addresses (including scale/offset), tiling data, and UB queues.
  __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                              GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe);
  // Row-loop driver: per loop, stage indices then quantize and scatter rows.
  __aicore__ inline void Process();

 private:
  __aicore__ inline void CopyInIndices(int64_t progress);
  __aicore__ inline void Compute();
  __aicore__ inline void CopyOut(int64_t progress);

 private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, BUFFER_NUM> inputXCopyInQueue;      // staged source tiles
  TQue<QuePosition::VECIN, BUFFER_NUM> expandRowIdxCopyInQueue; // destination row indices
  TQue<QuePosition::VECOUT, BUFFER_NUM> inputXCopyOutQueue;    // quantized int8 tiles
  TQue<QuePosition::VECOUT, 1> floatQueue;  // fp32 scratch for the bfloat16 path
  TQue<QuePosition::VECOUT, 1> halfQueue;   // half scratch for quantization
  GlobalTensor<T> inputXGm;
  GlobalTensor<int8_t> expandedXGm;
  GlobalTensor<int32_t> expandedRowIdxGm;
  GlobalTensor<float> scaleGm;
  GlobalTensor<float> offsetGm;
  const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData;
  int64_t needCoreNum;     // number of cores that actually do work
  int64_t blockIdx;        // this core's logical index
  int64_t cols;            // columns (elements) per row
  int64_t n;
  int64_t k;               // index entries per source row
  int64_t activateRows;    // rows kept when dropPadMode == DROPLESS_MODE
  int64_t currentLoopRows;
  int64_t coreRows;
  int64_t perLoopRows;
  int64_t lastLoopRows;
  int64_t rowLoops;
  int64_t colsTileLength;
  int64_t perLoopCols;
  int64_t lastLoopCols;
  int64_t colLoops;
  int64_t dropPadMode;
  float scale;   // quantization scale, read from scaleGm[0]
  float offset;  // quantization offset, read from offsetGm[0]
  int64_t indicesOffset;
  int64_t inputOffset;
  int64_t outOffset;
};
// Stage the expanded-row-index slice for row loop `progress` into UB:
// loads currentLoopRows int32 indices starting at progress * perLoopRows and
// enqueues them for the subsequent CopyOut stage.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::CopyInIndices(int64_t progress) {
  this->indicesOffset = this->perLoopRows * progress;
  LocalTensor<int32_t> idxTensor = expandRowIdxCopyInQueue.AllocTensor<int32_t>();
  DataCopyPadExtParams<int32_t> padParams{false, 0, 0, 0};
  DataCopyExtParams copyParams{1, static_cast<uint32_t>(sizeof(int32_t) * this->currentLoopRows), 0, 0, 0};
  DataCopyPad(idxTensor, expandedRowIdxGm[this->indicesOffset], copyParams, padParams);
  expandRowIdxCopyInQueue.EnQue<int32_t>(idxTensor);
}
// Quantize one staged column tile to int8: out = round(in * scale + offset),
// with a type-specific conversion chain per input dtype. pipe_barrier(PIPE_V)
// separates each dependent vector instruction.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::Compute() {
  LocalTensor<T> inLocal = inputXCopyInQueue.DeQue<T>();
  LocalTensor<int8_t> outLocal = inputXCopyOutQueue.AllocTensor<int8_t>();
  LocalTensor<float> floatLocal = floatQueue.AllocTensor<float>();
  LocalTensor<half> halfLocal = halfQueue.AllocTensor<half>();
  uint32_t elements = Align(this->colsTileLength, sizeof(T));
  if constexpr (IsSameType<T, bfloat16_t>::value) {
    // bfloat16: widen to fp32, narrow to half, apply scale/offset, then round
    // through int32 (with deq scale forced to 1.0) before the final int8 cast.
    Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements);
    pipe_barrier(PIPE_V);
    Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements);
    pipe_barrier(PIPE_V);
    Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
    pipe_barrier(PIPE_V);
    Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
    pipe_barrier(PIPE_V);
    // Reuse the fp32 scratch buffer as int32 storage for the rounding trip.
    LocalTensor<int32_t> intLocal = floatLocal.ReinterpretCast<int32_t>();
    Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements);
    pipe_barrier(PIPE_V);
    SetDeqScale((half)1.000000e+00f);
    pipe_barrier(PIPE_V);
    Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements);
    pipe_barrier(PIPE_V);
    Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
  } else if constexpr (IsSameType<T, float>::value) {
    // float: narrow to half, apply scale/offset, round to int8.
    Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements);
    pipe_barrier(PIPE_V);
    Muls(halfLocal, halfLocal, static_cast<half>(this->scale), elements);
    pipe_barrier(PIPE_V);
    Adds(halfLocal, halfLocal, static_cast<half>(this->offset), elements);
    pipe_barrier(PIPE_V);
    Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements);
  } else {
    // half: scale/offset in place, then round to int8.
    Muls(inLocal, inLocal, static_cast<T>(this->scale), elements);
    pipe_barrier(PIPE_V);
    Adds(inLocal, inLocal, static_cast<T>(this->offset), elements);
    pipe_barrier(PIPE_V);
    Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements);
  }
  inputXCopyOutQueue.EnQue(outLocal);
  floatQueue.FreeTensor(floatLocal);
  halfQueue.FreeTensor(halfLocal);
}
// Scatter stage with quantization: for each column tile, load each needed
// source row, quantize it via Compute(), then write the int8 result to every
// destination slot the staged index batch maps it to. Indices of -1 (and, in
// dropless mode, indices >= activateRows) are skipped.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::CopyOut(int64_t progress) {
  LocalTensor<int32_t> indicesLocal = expandRowIdxCopyInQueue.DeQue<int32_t>();
  // Ensure the index DMA has landed before scalar GetValue reads below.
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  colsTileLength = this->perLoopCols;
  for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) {
    // Global position of the first index entry handled by this loop.
    int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress;
    int64_t curLoopRow = 0;
    if (colsLoop == this->colLoops - 1) {
      colsTileLength = this->lastLoopCols;
    }
    // Each source row covers k consecutive entries of the index array.
    int64_t currentLoopStartRow = initialRow / this->k;
    int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k;
    for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) {
      LocalTensor<T> inLocal = inputXCopyInQueue.AllocTensor<T>();
      // input row position
      inputOffset = row * this->cols + colsLoop * this->perLoopCols;
      DataCopyExtParams dataCopyParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(T)), 0, 0, 0};
      DataCopyPadExtParams<T> dataCopyPadParams{false, 0, 0, 0};
      DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams);
      inputXCopyInQueue.EnQue<T>(inLocal);
      // Quantize the staged tile; result is enqueued on inputXCopyOutQueue.
      Compute();
      LocalTensor<int8_t> outLocal = inputXCopyOutQueue.DeQue<int8_t>();
      DataCopyExtParams intriParams{1, static_cast<uint32_t>(this->colsTileLength * sizeof(int8_t)), 0, 0, 0};
      while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) {
        int32_t outIndex = indicesLocal.GetValue(curLoopRow);
        curLoopRow++;
        initialRow++;
        if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) {
          continue;
        }
        outOffset = outIndex * cols + colsLoop * this->perLoopCols;
        DataCopyPad(expandedXGm[outOffset], outLocal, intriParams);
      }
      inputXCopyInQueue.FreeTensor(inLocal);
      inputXCopyOutQueue.FreeTensor(outLocal);
    }
  }
  expandRowIdxCopyInQueue.FreeTensor(indicesLocal);
}
// Binds GM tensors (including the scalar scale/offset), caches per-core
// tiling parameters, and allocates UB queues. The last working core uses the
// remainder ("lastCore*") tiling.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx,
                                                 GM_ADDR expandedX, GM_ADDR workspace,
                                                 const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe) {
  this->pipe = tPipe;
  // Logical core index combining block index and sub-block id.
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp);
  this->needCoreNum = this->gatherOutTilingData->needCoreNum;
  this->activateRows = this->gatherOutTilingData->activateRows;
  this->cols = tilingData->cols;
  this->n = tilingData->n;
  this->k = tilingData->k;
  this->dropPadMode = tilingData->dropPadMode;
  if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) {
    // Tail core: takes the remainder rows.
    this->coreRows = this->gatherOutTilingData->lastCoreRows;
    this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows;
    this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows;
    this->rowLoops = this->gatherOutTilingData->lastCoreLoops;
  } else {
    this->coreRows = this->gatherOutTilingData->perCoreRows;
    this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows;
    this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows;
    this->rowLoops = this->gatherOutTilingData->perCoreLoops;
  }
  this->perLoopCols = this->gatherOutTilingData->perLoopCols;
  this->lastLoopCols = this->gatherOutTilingData->lastLoopCols;
  this->colLoops = this->gatherOutTilingData->colLoops;
  inputXGm.SetGlobalBuffer((__gm__ T*)inputX);
  expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX);
  // Index buffer is offset to this core's slice of the row indices.
  expandedRowIdxGm.SetGlobalBuffer(
      (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows,
      Align(this->coreRows, sizeof(int32_t)));
  scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1);
  offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1);
  // Scalar reads of the single-element quantization parameters.
  this->scale = scaleGm.GetValue(0);
  this->offset = offsetGm.GetValue(0);
  pipe->InitBuffer(inputXCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T)));
  pipe->InitBuffer(inputXCopyOutQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(int8_t)));
  pipe->InitBuffer(expandRowIdxCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t)));
  pipe->InitBuffer(floatQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
  pipe->InitBuffer(halfQueue, 1, AlignBytes(this->perLoopCols, sizeof(half)));
}
// Row-loop driver: iterates over this core's row loops, staging destination
// indices and quantize-scattering the corresponding rows. The final loop may
// process a shorter (remainder) batch.
template <typename T>
__aicore__ inline void MoeV2GatherQuant<T>::Process() {
  // Cores beyond the tiling's working set stay idle.
  if (this->blockIdx >= this->needCoreNum) {
    return;
  }
  currentLoopRows = perLoopRows;
  for (int64_t loop = 0; loop < this->rowLoops; loop++) {
    if (loop == this->rowLoops - 1) {
      currentLoopRows = lastLoopRows;
    }
    CopyInIndices(loop);
    CopyOut(loop);
  }
}
} // namespace MoeInitRoutingQuantV2
#endif // MOE_V2_GATHER_QUANT_H

View File

@@ -0,0 +1,312 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/* !
* \file moe_v2_init_routing_fullload.h
* \brief
*/
#ifndef INNER_MOE_V2_FULL_LOAD_H
#define INNER_MOE_V2_FULL_LOAD_H
#include "moe_v2_mrgsort.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Full-load routing path: when all expert ids fit in UB at once, sort them,
// derive the expanded row mapping, optionally compute per-expert token counts
// or cumulative sums, and scatter the activations — all within one kernel.
template <typename T>
class MoeV2FullLoad : public MoeV2SortBase {
 public:
  __aicore__ inline MoeV2FullLoad(){};
  // Binds GM addresses and tiling data; sizes the UB queues and sort buffers.
  __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                              GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                              const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe);
  // Pipeline driver: load, sort, emit indices/stats, scatter activations.
  __aicore__ inline void Process();

 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void SortCompute();
  __aicore__ inline void CopyOutIdx();
  __aicore__ inline void CopyOutEmpty();
  __aicore__ inline void CopyOutX();
  __aicore__ inline void ComputeExpertTokenCountOrCumsum();

 private:
  int64_t sortNum_;  // totalLength rounded up to whole sort repeats
  const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_;
  int64_t blockIdx_;
  int64_t needCoreNum_;
  int64_t coreRows_;     // rows this core scatters in CopyOutX
  int64_t perCoreRows_;
  int64_t k_;            // index entries per source row
  int64_t n_;
  int64_t cols_;
  int64_t activateRows_;
  int64_t expertNum;
  int64_t expertCapacity;
  TQue<QuePosition::VECIN, 1> xCopyInQueue_;
  TQue<QuePosition::VECOUT, 1> expandedRowIdxCopyOutQueue_;
  TQue<QuePosition::VECOUT, 1> expandedExpertIdxCopyOutQueue_;
  TQue<QuePosition::VECOUT, 1> expandDstToSrcRowQueue_;
  TQue<QuePosition::VECOUT, 1> expertTokensCopyOutQueue_;
  GlobalTensor<T> xGm_;
  GlobalTensor<int32_t> expertIdxGm_;
  GlobalTensor<T> expandedXGm_;
  GlobalTensor<int32_t> expandedRowIdxGm_;
  GlobalTensor<int32_t> expandedExpertIdxGm_;
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;
  // EXERPT_TOKENS_COUNT / EXERPT_TOKENS_CUMSUM selector; 0 disables stats.
  int64_t expertTokensCountOrCumsumFlag = 0;
  int64_t expertTokensBeforeCapacityFlag = 0;
  int64_t dropPadMode = 0;
};
// Load all totalLength expert ids into the first half of the sort buffer and
// generate an iota ramp (0..totalLength-1) of original row positions in the
// second half; together they form the key/value input of the sort.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyIn() {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
  DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                   0, 0, 0};
  DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
  DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams);
  // Row ids start at offset sortNum_ inside the same tensor.
  ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
  sortDataCopyInQueue.EnQue(inLocal);
}
// Two-pass sort over the full expert-id list:
//  1) sort (negated) expert ids carrying the source row as payload, yielding
//     expandedExpertIdx plus the dst->src row mapping;
//  2) sort the (negated) dst->src mapping carrying an iota payload, yielding
//     the inverse src->dst mapping expandedRowIdx.
// Keys are negated because Sort is descending; the tail of a partial sort
// repeat is padded with MIN_FP32 so padding sinks to the end.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::SortCompute() {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> expertIdxLocal = inLocal[0];
  LocalTensor<float> expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast<float>();
  // int32 expert ids -> fp32 keys, negated for the descending sort.
  Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
  pipe_barrier(PIPE_V);
  Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength);
  pipe_barrier(PIPE_V);
  // Pad the final partial repeat with MIN_FP32 so it sorts behind real data.
  int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
  if (duplicateNum > 0) {
    int duplicateIndex = this->totalLength - duplicateNum;
    uint64_t mask0 = UINT64_MAX;
    mask0 = mask0 << duplicateNum;
    mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
    uint64_t mask[2] = {mask0, 0};
    Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
    pipe_barrier(PIPE_V);
  }
  LocalTensor<float> concatLocal;
  LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
  Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  // First sort: key = expert id, payload = original row position.
  LocalTensor<uint32_t> rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast<uint32_t>();
  LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum_));
  Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  LocalTensor<float> expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor<float>();
  LocalTensor<uint32_t> expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor<uint32_t>();
  LocalTensor<float> expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast<float>();
  Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast<int32_t>(), RoundMode::CAST_ROUND,
       this->totalLength);
  pipe_barrier(PIPE_V);
  // Undo the negation to recover the sorted expert ids as int32.
  Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength);
  pipe_barrier(PIPE_V);
  LocalTensor<int32_t> expandedExpertIdxLocalInt32;
  expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast<int32_t>();
  Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength);
  pipe_barrier(PIPE_V);
  expandedExpertIdxCopyOutQueue_.EnQue<int32_t>(expandedExpertIdxLocalInt32);
  // Second sort: key = dst->src row (negated), payload = iota; its Extract
  // produces the inverse permutation expandedRowIdx.
  LocalTensor<uint32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor<uint32_t>();
  LocalTensor<uint32_t> expandedRowIdxU32 = expandedRowIdx.ReinterpretCast<uint32_t>();
  Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength);
  pipe_barrier(PIPE_V);
  ArithProgression<int32_t>(inLocal[this->sortNum_], 0, 1, this->totalLength);
  pipe_barrier(PIPE_V);
  if (duplicateNum > 0) {
    int duplicateIndex = this->totalLength - duplicateNum;
    uint64_t mask0 = UINT64_MAX;
    mask0 = mask0 << duplicateNum;
    mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
    uint64_t mask[2] = {mask0, 0};
    Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
    pipe_barrier(PIPE_V);
  }
  Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  Sort<float, true>(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  expandedRowIdxCopyOutQueue_.EnQue<uint32_t>(expandedRowIdx);
  sortDataCopyInQueue.FreeTensor(inLocal);
  expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal);
}
// Write the full expandedRowIdx mapping to GM. Note the tensor is re-enqueued
// (not freed) because CopyOutX dequeues the same tensor again later.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyOutIdx() {
  LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  intriParams.blockLen = this->totalLength * sizeof(int32_t);
  DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams);
  // Keep the tensor alive for CopyOutX.
  expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx);
}
// Scalar pass over the sorted expert ids producing, per expert, either the
// token count (EXERPT_TOKENS_COUNT) or the running cumulative sum
// (EXERPT_TOKENS_CUMSUM), then writes the result to GM.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::ComputeExpertTokenCountOrCumsum() {
  LocalTensor<int32_t> expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
  LocalTensor<int32_t> expertTokensCount = expertTokensCopyOutQueue_.AllocTensor<int32_t>();
  int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t));
  Duplicate(expertTokensCount, 0, expertNumAlign);
  // Wait for the vector zero-fill before scalar SetValue/GetValue below.
  SetWaitFlag<HardEvent::V_S>(HardEvent::V_S);
  int32_t lastExpertId = expandedExpertIdx.GetValue(0);
  int64_t tokenCount = 0;
  int64_t lastExpertCount = 0;
  for (int64_t i = 0; i < this->totalLength; i++) {
    int32_t curExpertId = expandedExpertIdx.GetValue(i);
    tokenCount++;
    // When the expert id advances, close out every expert in between
    // (experts with zero tokens inherit the running value in cumsum mode).
    while (lastExpertId < curExpertId) {
      expertTokensCount.SetValue(lastExpertId, tokenCount - 1);
      if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) {
        // Count mode: restart counting at the current (already seen) token.
        tokenCount = 1;
      }
      lastExpertId++;
    }
  }
  // Close out the final expert seen in the data.
  expertTokensCount.SetValue(lastExpertId, tokenCount);
  if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) {
    // Cumsum mode: trailing experts with no tokens keep the final total.
    lastExpertId++;
    while (lastExpertId < this->expertNum) {
      expertTokensCount.SetValue(lastExpertId, tokenCount);
      lastExpertId++;
    }
  }
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->expertNum * sizeof(int32_t)), 0, 0,
                               0};
  if (this->expertTokensCountOrCumsumFlag > 0) {
    DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams);
  }
  expertTokensCopyOutQueue_.FreeTensor(expertTokensCount);
  expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx);
}
// Scatter this core's slice of activations: bulk-load the contiguous source
// rows the slice touches, then for each (row, k) entry copy the row to the
// destination given by expandedRowIdx, skipping entries >= activateRows.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyOutX() {
  LocalTensor<T> xLocal = xCopyInQueue_.AllocTensor<T>();
  LocalTensor<int32_t> expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue<int32_t>();
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  intriParams.blockLen = this->cols_ * sizeof(T);
  // UB stride between consecutive staged rows (element-aligned cols).
  int64_t inFactor = Align(this->cols_, sizeof(T));
  int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
  // Source rows covered by this core's index range (k entries per row).
  int64_t startXRow = curRowsStart / this->k_;
  int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_;
  DataCopyExtParams dataXCopyParams{static_cast<uint16_t>(endXRow - startXRow + 1),
                                    static_cast<uint32_t>(this->cols_ * sizeof(T)), 0, 0, 0};
  DataCopyPadExtParams<T> dataXCopyPadParams{false, 0, 0, 0};
  DataCopyPad(xLocal, xGm_[startXRow * this->cols_], dataXCopyParams, dataXCopyPadParams);
  // Ensure the bulk load has landed before scalar GetValue reads below.
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  int64_t k = 0;
  for (int64_t i = startXRow; i <= endXRow; i++) {
    for (; k < this->perCoreRows_ && curRowsStart / this->k_ == i; curRowsStart++, k++) {
      int32_t outIndex = expandedRowIdx.GetValue(curRowsStart);
      if (outIndex < this->activateRows_) {
        DataCopyPad(expandedXGm_[outIndex * this->cols_], xLocal[(i - startXRow) * inFactor], intriParams);
      }
    }
  }
  expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx);
  xCopyInQueue_.FreeTensor(xLocal);
}
// Drain the sorted-expert-idx queue without writing anything, keeping the
// queue's alloc/free balanced on cores that skip the token-count stage.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::CopyOutEmpty() {
  LocalTensor<int32_t> drained = expandedExpertIdxCopyOutQueue_.DeQue<int32_t>();
  expandedExpertIdxCopyOutQueue_.FreeTensor(drained);
}
// Binds GM tensors, caches tiling parameters, and sizes the UB queues and
// sort scratch buffers. sortNum_ is totalLength rounded up to whole sort
// repeats; sort buffers hold key/value pairs (kvFactor = 2).
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx,
                                              GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace,
                                              const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe) {
  this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp);
  // Logical core index combining block index and sub-block id.
  this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num();
  this->k_ = tilingData->k;
  this->n_ = tilingData->n;
  this->cols_ = tilingData->cols;
  this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum;
  this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows;
  this->activateRows_ = this->gatherOutTilingData_->activateRows;
  if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) {
    // Tail core: takes the remainder rows.
    this->coreRows_ = this->gatherOutTilingData_->lastCoreRows;
  } else {
    this->coreRows_ = this->gatherOutTilingData_->perCoreRows;
  }
  this->expertNum = tilingData->expertNum;
  this->dropPadMode = tilingData->dropPadMode;
  this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
  this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
  // Round up to a multiple of one sort repeat.
  this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
  this->totalLength = tilingData->n * tilingData->k;
  this->pipe = tPipe;
  xGm_.SetGlobalBuffer((__gm__ T*)x);
  expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
  expandedXGm_.SetGlobalBuffer((__gm__ T*)expandedX);
  expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength);
  if (this->expertTokensCountOrCumsumFlag > 0) {
    // dropless
    expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                Align(this->expertNum, sizeof(int32_t)));
  }
  // Sort scratch holds interleaved key/value pairs.
  int64_t kvFactor = 2;
  int64_t buffSize = this->sortNum_ * sizeof(int32_t);
  int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_;
  int64_t startXRow = curRowsStart / this->k_;
  int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_;
  // Activation staging holds every source row this core's slice touches.
  pipe->InitBuffer(xCopyInQueue_, bufferNum, AlignBytes(this->cols_, sizeof(T)) * (endXRow - startXRow + 1));
  pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize);
  pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize);
  pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t)));
  pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize);
  pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor);
  pipe->InitBuffer(tempBuffer, buffSize * kvFactor);
  pipe->InitBuffer(sortedBuffer, buffSize * kvFactor);
}
// Pipeline driver: load and sort the full expert-id list, emit the row-index
// mapping (core 0 only), emit per-expert token stats (last working core only,
// when requested), then scatter this core's slice of the activations.
template <typename T>
__aicore__ inline void MoeV2FullLoad<T>::Process() {
  // Cores outside the tiling plan stay idle.
  if (this->blockIdx_ >= this->needCoreNum_) {
    return;
  }
  CopyIn();
  SortCompute();
  if (this->blockIdx_ == 0) {
    CopyOutIdx();
  }
  bool isLastWorkingCore = (this->blockIdx_ == this->needCoreNum_ - 1);
  if (isLastWorkingCore && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
    ComputeExpertTokenCountOrCumsum();
  } else {
    CopyOutEmpty();
  }
  CopyOutX();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_FULL_LOAD_H

View File

@@ -0,0 +1,189 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_mrgsort.h
* \brief
*/
#ifndef INNER_MOE_V2_MRGSORT_H
#define INNER_MOE_V2_MRGSORT_H
#include "moe_v2_common.h"
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Tiling parameters for the k-way merge sort.
struct MoeV2MrgsortParam {
  int64_t perListElements;     // elements in every input list except the last
  int64_t lastListElements;    // elements in the final (possibly shorter) list
  int64_t oneLoopMaxElements;  // max elements staged per list per merge pass
};
// Streaming k-way (up to 4) merge of pre-sorted key/value lists resident in
// GM: repeatedly stages a bounded chunk of each list in UB, merges with the
// MrgSort instruction, and appends the merged run to the output in GM.
class MoeV2Mrgsort {
 public:
  __aicore__ inline MoeV2Mrgsort(){};
  // Captures tiling and resets per-list remaining counts and offsets.
  __aicore__ inline void Init(MoeV2MrgsortParam* param);
  // Merge loop: runs until every input element has been consumed.
  __aicore__ inline void Process();
  // Registers one sorted input list (GM source + UB staging buffer).
  __aicore__ inline void SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput);
  // Registers the merged output (GM destination + UB staging buffer).
  __aicore__ inline void SetOutput(GlobalTensor<float>& gmOutput, LocalTensor<float>& ubOutput);

 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void UpdateMrgParam();
  __aicore__ inline void MrgsortCompute();
  __aicore__ inline void UpdateSortInfo();
  __aicore__ inline void CopyOut();
  __aicore__ inline void ClearCache();

 private:
  MoeV2MrgsortParam* param = nullptr;
  GlobalTensor<float> gmInputs[4];  // up to 4 input lists
  GlobalTensor<float> gmOutput;
  LocalTensor<float> ubInputs[4];
  LocalTensor<float> ubOutput;
  int64_t listNum{0};        // registered input lists
  int64_t remainListNum{0};  // lists with data left in the current pass
  int64_t outOffset{0};      // write position in gmOutput (sort-pair units)
  int64_t offsets[4];        // per-list read position (sort-pair units)
  int64_t listRemainElements[4];
  int64_t lengths[4];        // elements staged per list this pass
  int64_t allRemainElements{0};
  int64_t curLoopSortedNum{0};  // elements merged in the current pass
  // for MrgSort
  uint16_t validBitTail{0};           // bitmask of active source lists
  uint16_t elementCountListTail[4];   // staged element counts, compacted
  uint32_t listSortedNums[4];         // elements MrgSort consumed per list
  LocalTensor<float> tmpUbInputs[4];  // compacted UB sources for MrgSort
};
// Reset the merge bookkeeping so the instance can be reused for another round.
__aicore__ inline void MoeV2Mrgsort::ClearCache() {
  this->outOffset = 0;
  this->allRemainElements = 0;
  this->listNum = 0;
}
// Register one pre-sorted input list: its GM source and UB staging buffer.
__aicore__ inline void MoeV2Mrgsort::SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput) {
  int64_t slot = this->listNum;
  this->gmInputs[slot] = gmInput;
  this->ubInputs[slot] = ubInput;
  this->listNum = slot + 1;
}
// Register the merged-output destination: GM target and UB staging buffer.
__aicore__ inline void MoeV2Mrgsort::SetOutput(GlobalTensor<float>& gmOutput, LocalTensor<float>& ubOutput) {
  this->ubOutput = ubOutput;
  this->gmOutput = gmOutput;
}
// Derive the MrgSort source mask from how many lists still hold data, and
// zero the tail counts of unused source slots.
__aicore__ inline void MoeV2Mrgsort::UpdateMrgParam() {
  // Default: a single remaining list.
  validBitTail = 0b0001;
  if (this->remainListNum == MERGE_LIST_FOUR) {
    validBitTail = 0b1111;
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
    validBitTail = 0b0111;
  } else if (this->remainListNum == MERGE_LIST_TWO) {
    elementCountListTail[MERGE_LIST_IDX_TWO] = 0;
    elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
    validBitTail = 0b0011;
  }
}
// Stage the next chunk (up to oneLoopMaxElements) of every non-empty list
// into UB, compacting active lists into tmpUbInputs/elementCountListTail so
// MrgSort sees a dense source array.
__aicore__ inline void MoeV2Mrgsort::CopyIn() {
  this->remainListNum = 0;
  // Previous pass's output copy must finish before reusing the UB buffers.
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]);
    if (lengths[i] > 0) {
      DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen<float>(lengths[i]), sizeof(float)));
      tmpUbInputs[j] = this->ubInputs[i];
      elementCountListTail[j] = lengths[i];
      this->remainListNum += 1;
      j++;
    }
  }
}
// Merge the staged chunks with the MrgSort instruction, picking the source
// list shape matching remainListNum; with one list left, it is just copied.
// listSortedNums receives how many elements MrgSort consumed from each list.
__aicore__ inline void MoeV2Mrgsort::MrgsortCompute() {
  // Staged data must be in UB before the vector merge runs.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  if (this->remainListNum == MERGE_LIST_TWO) {
    // Unused source slots are filled with list 0; validBitTail masks them off.
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    MrgSortSrcList sortListTail =
        MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_FOUR) {
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO],
                                                 tmpUbInputs[MERGE_LIST_IDX_THREE]);
    MrgSort<float, true>(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else {
    // Single remaining list: pass it through unchanged.
    DataCopy(this->ubOutput, this->tmpUbInputs[0], Align(GetSortLen<float>(elementCountListTail[0]), sizeof(float)));
    listSortedNums[0] = elementCountListTail[0];
  }
}
// Account for what MrgSort consumed: advance each active list's offset,
// shrink the remaining counts, and total this pass's merged element count.
// j indexes the compacted (active-list) arrays while i indexes all lists.
__aicore__ inline void MoeV2Mrgsort::UpdateSortInfo() {
  curLoopSortedNum = 0;
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    if (lengths[i] > 0) {
      // update remain size
      listRemainElements[i] -= listSortedNums[j];
      allRemainElements -= listSortedNums[j];
      // update offset
      offsets[i] += GetSortOffset<float>(listSortedNums[j]);
      // update current loop sorted nums
      curLoopSortedNum += listSortedNums[j];
      j += 1;
    }
  }
}
// Flush this loop's merged (key, index) records from UB back to GM and
// advance the running output offset.
__aicore__ inline void MoeV2Mrgsort::CopyOut() {
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  // blockLen is in bytes; each element occupies a full sort record.
  intriParams.blockLen = GetSortLen<float>(curLoopSortedNum) * sizeof(float);
  // Wait for vector compute to finish before MTE3 reads the UB output.
  SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
  DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams);
  outOffset += GetSortLen<float>(curLoopSortedNum);
}
// Bind the tiling parameters and reset the per-list cursors. Lists
// 0..listNum-2 hold perListElements each; the final list holds
// lastListElements (it may be shorter).
__aicore__ inline void MoeV2Mrgsort::Init(MoeV2MrgsortParam* param) {
  this->param = param;
  this->remainListNum = listNum;
  for (int64_t i = 0; i < listNum; ++i) {
    offsets[i] = GetSortOffset<float>(param->perListElements * i);
    const bool lastList = (i == listNum - 1);
    listRemainElements[i] = lastList ? param->lastListElements : param->perListElements;
    allRemainElements += listRemainElements[i];
  }
}
// Run merge iterations until every input element has been written out,
// then reset the internal bookkeeping so the merger can be reused.
__aicore__ inline void MoeV2Mrgsort::Process() {
  while (allRemainElements > 0) {
    CopyIn();
    UpdateMrgParam();
    MrgsortCompute();
    UpdateSortInfo();
    CopyOut();
  }
  ClearCache();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_MRGSORT_H

View File

@@ -0,0 +1,213 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_mrgsort_out.h
* \brief
*/
#ifndef INNER_MOE_V2_MRGSORT_OUT_H
#define INNER_MOE_V2_MRGSORT_OUT_H
#include "moe_v2_mrgsort.h"
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Final-stage 4-way merger: merges up to four pre-sorted (key, index) lists
// from workspace GM and, per element, writes two int32 GM outputs -- the
// restored sort key and the source-row index extracted from the sort record.
class MoeV2MrgsortOut {
 public:
  __aicore__ inline MoeV2MrgsortOut(){};
  __aicore__ inline void Init(MoeV2MrgsortParam* param, TPipe* tPipe);
  __aicore__ inline void Process();
  __aicore__ inline void SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput);
  __aicore__ inline void SetOutput(GlobalTensor<int32_t>& gmOutput1, GlobalTensor<int32_t>& gmOutput2,
                                   LocalTensor<float>& ubOutput1, LocalTensor<float>& ubOutput2);
  __aicore__ inline void SetBuffer(LocalTensor<float>& tempBuffer);
 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void UpdateMrgParam();
  __aicore__ inline void MrgsortCompute();
  __aicore__ inline void UpdateSortInfo();
  __aicore__ inline void Extract();
  __aicore__ inline void CopyOut();
  __aicore__ inline void ClearCache();
 private:
  MoeV2MrgsortParam* param = nullptr;  // tiling-derived loop parameters
  GlobalTensor<float> gmInputs[4];     // up to four sorted source lists in GM
  GlobalTensor<int32_t> gmOutput1;     // restored keys output
  GlobalTensor<int32_t> gmOutput2;     // extracted indices output
  LocalTensor<float> ubInputs[4];      // UB staging slices, one per list
  LocalTensor<float> tempBuffer;       // merge destination before Extract
  // for extract
  LocalTensor<float> ubOutput1;
  LocalTensor<uint32_t> ubOutput2;
  // for copy out (int32 views over the same UB memory as ubOutput1/2)
  LocalTensor<int32_t> ubOutputInt1;
  LocalTensor<int32_t> ubOutputInt2;
  int64_t listNum{0};              // number of registered input lists
  int64_t remainListNum{0};        // lists still holding data this loop
  int64_t outOffset{0};            // elements already written to GM outputs
  int64_t offsets[4];              // per-list GM read offsets (record units)
  int64_t listRemainElements[4];   // per-list unconsumed element counts
  int64_t lengths[4];              // per-list chunk sizes for current loop
  int64_t allRemainElements{0};    // total unconsumed elements
  int64_t curLoopSortedNum{0};     // elements produced in current loop
  // for MrgSort
  uint16_t validBitTail;           // bitmask of valid source slots
  uint16_t elementCountListTail[4];
  uint32_t listSortedNums[4];      // per-slot consumed counts from MrgSort
  LocalTensor<float> tmpUbInputs[4];  // compacted active-list UB slices
};
// Reset the bookkeeping state so the merger can be configured again.
__aicore__ inline void MoeV2MrgsortOut::ClearCache() {
  listNum = 0;
  allRemainElements = 0;
  outOffset = 0;
}
// Register one more (GM source, UB staging) pair; listNum counts how many
// inputs have been registered so far and selects the next free slot.
__aicore__ inline void MoeV2MrgsortOut::SetInput(GlobalTensor<float>& gmInput, LocalTensor<float>& ubInput) {
  const int64_t slot = listNum;
  gmInputs[slot] = gmInput;
  ubInputs[slot] = ubInput;
  listNum = slot + 1;
}
// Register the two GM destinations and the two UB output buffers.
// ReinterpretCast creates typed views over the same UB memory: ubOutput1 is
// filled as float and round-cast to int32 in-place (ubOutputInt1); ubOutput2
// holds the uint32 indices produced by Extract and is copied out as int32.
__aicore__ inline void MoeV2MrgsortOut::SetOutput(GlobalTensor<int32_t>& gmOutput1, GlobalTensor<int32_t>& gmOutput2,
                                                  LocalTensor<float>& ubOutput1, LocalTensor<float>& ubOutput2) {
  this->gmOutput1 = gmOutput1;
  this->ubOutput1 = ubOutput1;
  this->ubOutputInt1 = ubOutput1.ReinterpretCast<int32_t>();
  this->gmOutput2 = gmOutput2;
  this->ubOutput2 = ubOutput2.ReinterpretCast<uint32_t>();
  this->ubOutputInt2 = ubOutput2.ReinterpretCast<int32_t>();
}
// Register the UB scratch buffer that receives MrgSort's merged records
// before Extract splits them into keys and indices.
__aicore__ inline void MoeV2MrgsortOut::SetBuffer(LocalTensor<float>& tempBuffer) {
  this->tempBuffer = tempBuffer;
}
// Derive the MrgSort source-valid bitmask from the number of lists that
// still hold data this loop, zeroing the element counts of unused slots.
__aicore__ inline void MoeV2MrgsortOut::UpdateMrgParam() {
  // Default: a single remaining list -> only slot 0 participates.
  validBitTail = 0b0001;
  if (this->remainListNum == MERGE_LIST_FOUR) {
    validBitTail = 0b1111;
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
    validBitTail = 0b0111;
  } else if (this->remainListNum == MERGE_LIST_TWO) {
    elementCountListTail[MERGE_LIST_IDX_TWO] = 0;
    elementCountListTail[MERGE_LIST_IDX_THREE] = 0;
    validBitTail = 0b0011;
  }
}
// Stage the next chunk of every non-exhausted input list from GM into UB,
// compacting active lists into tmpUbInputs and recording chunk sizes.
__aicore__ inline void MoeV2MrgsortOut::CopyIn() {
  this->remainListNum = 0;
  // Wait for the previous CopyOut (MTE3) before overwriting UB via MTE2.
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
  for (int64_t i = 0, j = 0; i < listNum; i++) {
    // Per-loop chunk size is capped by the tiling's UB capacity.
    lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]);
    if (lengths[i] > 0) {
      // Data is stored as sorted (key, index) records, hence GetSortLen.
      DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen<float>(lengths[i]), sizeof(float)));
      tmpUbInputs[j] = this->ubInputs[i];
      elementCountListTail[j] = lengths[i];
      this->remainListNum += 1;
      j++;
    }
  }
}
// Merge the staged chunks into tempBuffer with the hardware MrgSort
// instruction. Two to four active lists are merged in one call (unused
// slots alias list 0 and are masked off by validBitTail); a single list
// is copied straight through.
__aicore__ inline void MoeV2MrgsortOut::MrgsortCompute() {
  // Ensure CopyIn's MTE2 transfers landed before vector compute reads UB.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  if (this->remainListNum == MERGE_LIST_TWO) {
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_THREE) {
    MrgSortSrcList sortListTail =
        MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else if (this->remainListNum == MERGE_LIST_FOUR) {
    MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO],
                                                 tmpUbInputs[MERGE_LIST_IDX_THREE]);
    MrgSort<float, true>(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1);
  } else {
    // Only one list remains: pass its chunk through unchanged.
    DataCopy(this->tempBuffer, this->tmpUbInputs[0], Align(GetSortLen<float>(elementCountListTail[0]), sizeof(float)));
    listSortedNums[0] = elementCountListTail[0];
  }
}
// Account for the elements consumed from each active list in this loop:
// shrink the per-list and global remainders, advance each list's GM read
// offset, and tally how many sorted elements this loop produced.
__aicore__ inline void MoeV2MrgsortOut::UpdateSortInfo() {
  curLoopSortedNum = 0;
  int64_t j = 0;
  for (int64_t i = 0; i < listNum; ++i) {
    if (lengths[i] <= 0) {
      continue;  // this list contributed nothing this loop
    }
    const int64_t consumed = listSortedNums[j];
    listRemainElements[i] -= consumed;
    allRemainElements -= consumed;
    offsets[i] += GetSortOffset<float>(consumed);
    curLoopSortedNum += consumed;
    ++j;
  }
}
// Split the merged sort records: keys to ubOutput1, indices to ubOutput2.
// The keys were negated before sorting (see the upstream sort stages), so
// multiply by -1 to restore the original values, then round-cast the float
// keys back to int32 in-place via the ubOutputInt1 view.
__aicore__ inline void MoeV2MrgsortOut::Extract() {
  AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM));
  pipe_barrier(PIPE_V);
  Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float)));
  pipe_barrier(PIPE_V);
  Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float)));
}
// Write this loop's keys and indices (both as int32) to their GM outputs
// and advance the shared output offset.
__aicore__ inline void MoeV2MrgsortOut::CopyOut() {
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  intriParams.blockLen = curLoopSortedNum * sizeof(int32_t);
  // Wait for vector compute (Extract/Muls/Cast) before MTE3 reads UB.
  SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
  DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams);
  DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams);
  outOffset += curLoopSortedNum;
}
// Bind the tiling parameters and reset the per-list cursors. Lists
// 0..listNum-2 hold perListElements each; the final list holds
// lastListElements. NOTE(review): tPipe is accepted but unused here.
__aicore__ inline void MoeV2MrgsortOut::Init(MoeV2MrgsortParam* param, TPipe* tPipe) {
  this->param = param;
  // Re-zero explicitly so a stale total from a prior round cannot leak in.
  this->allRemainElements = 0;
  for (int64_t i = 0; i < listNum; i++) {
    offsets[i] = GetSortOffset<float>(param->perListElements * i);
    if (i == listNum - 1) {
      listRemainElements[i] = param->lastListElements;
    } else {
      listRemainElements[i] = param->perListElements;
    }
    allRemainElements += listRemainElements[i];
  }
}
// Run merge iterations until every input element has been extracted and
// written out, then reset the bookkeeping for potential reuse.
__aicore__ inline void MoeV2MrgsortOut::Process() {
  while (allRemainElements > 0) {
    CopyIn();
    UpdateMrgParam();
    MrgsortCompute();
    UpdateSortInfo();
    Extract();
    CopyOut();
  }
  ClearCache();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_MRGSORT_OUT_H

View File

@@ -0,0 +1,70 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_sort_base.h
* \brief
*/
#ifndef INNER_MOE_V2_SORT_BASE_H
#define INNER_MOE_V2_SORT_BASE_H
#include "kernel_operator.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Shared state and helpers for the MoE-routing sort stages: UB queues and
// scratch buffers, the GM tensors every stage touches, and the tiling
// scalars (token count n, top-k k, expert count, output flags).
class MoeV2SortBase {
 public:
  __aicore__ inline MoeV2SortBase(){};
 protected:
  __aicore__ inline void SyncAll();
 protected:
  TPipe* pipe;                                   // pipe for UB buffer allocation
  TQue<QuePosition::VECIN, 1> sortDataCopyInQueue;
  TQue<QuePosition::VECOUT, 1> sortDataCopyOutQueue;
  TBuf<TPosition::VECCALC> tempBuffer;           // scratch for Concat/Sort
  TBuf<TPosition::VECCALC> sortedBuffer;         // scratch for sorted records
  GlobalTensor<int32_t> expertIdxGm;             // input expert ids per token*k
  GlobalTensor<int32_t> sortedexpertIdxGm;       // workspace: sorted expert ids
  GlobalTensor<int32_t> expandDstToSrcRowGm;     // workspace: dst->src row map
  GlobalTensor<int32_t> expertTokensCountOrCumsumGm;    // optional output
  GlobalTensor<int32_t> expertTokensBeforeCapacityGm;   // optional output
  int64_t tileLength;
  int64_t bufferNum = 1;            // queue depth (single buffering)
  int64_t totalLength;              // n * k elements to sort
  int64_t coreNum;
  int64_t n;                        // number of tokens
  int64_t k;                        // experts selected per token
  int64_t existRowIdx;
  int64_t expertNum;
  int64_t expertTokensCountOrCumsumFlag = 0;     // 0 = none; >0 selects mode
  int64_t expertTokensBeforeCapacityFlag = 0;    // 1 = emit before-capacity counts
  static constexpr int64_t SYNC_GM_NUM = 2;
  static constexpr int64_t WORK_GM_NUM = 2;      // ping-pong workspace count
  static constexpr int64_t DST_BLK_STRIDE = 1;
  static constexpr int64_t DST_REP_STRIDE = 8;
};
// Hard inter-core barrier. A single participating core needs no barrier,
// and in the CPU-simulation build (__CCE_KT_TEST__) the hardware barrier
// is compiled out.
__aicore__ inline void MoeV2SortBase::SyncAll() {
  if (coreNum == 1) {
    return;
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_SORT_BASE_H

View File

@@ -0,0 +1,373 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_sort_multi_core.h
* \brief
*/
#ifndef INNER_MOE_V2_VBS_ONE_CORE_H
#define INNER_MOE_V2_VBS_ONE_CORE_H
#include "moe_v2_sort_base.h"
#include "moe_v2_mrgsort.h"
#include "moe_v2_mrgsort_out.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Multi-core sort pipeline for MoE routing: per-core in-UB sort (VBS),
// multi-round 4-way merge across cores (VMS), and a final single-core
// merge-out stage that materializes sorted expert ids and the dst->src map.
class MoeV2SortMultiCore : public MoeV2SortBase {
 public:
  __aicore__ inline MoeV2SortMultiCore(){};
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                              GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
 private:
  __aicore__ inline void VBSProcess();
  __aicore__ inline void UBSortProcess(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void OneCoreVMSProcess(int64_t listNum, int64_t perListElements, int64_t lastListElements);
  __aicore__ inline void VMSProcess();
  __aicore__ inline void SortOutProcess();
  __aicore__ inline void VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void UBSortCompute(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum);
  __aicore__ inline void InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset, int64_t loopOffset);
  __aicore__ inline void InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum, int64_t coreOffset);
  __aicore__ inline void InitExpertTokensGlobalMemory();
 private:
  GlobalTensor<float> workspaceGms[2];  // ping-pong key/value workspaces in GM
  const InnerMoeV2VBSComputeTilingData* vbsTilingData;
  const InnerMoeV2VMSMiddleComputeTilingData* vmsTilingData;
  const InnerMoeV2SortOutComputeTilingData* sortOutTilingData;
  // for MoeMrgsort
  MoeV2Mrgsort mrgsorter;
  MoeV2MrgsortParam mrgsortParam;
  int64_t coreNum;
  int64_t blockIdx;                 // linear core index (block + sub-block)
  int64_t srcWsIndex = 0;           // which workspace currently holds sources
  int64_t listNum;                  // sorted lists still to be merged
  int64_t perListElements;
  int64_t lastListElements;
  int64_t sortTotalLength;          // elements this core sorts in VBS
  int64_t sortCoreLoops;
  int64_t sortCoreLoopElements;
  int64_t sortCoreLastLoopElements;
  int64_t perCoreExpert;            // experts zero-initialized per core
  int64_t needInitExpertCore;       // cores participating in expert init
  int64_t currentCoreExpert;        // experts this core zero-initializes
  static constexpr int64_t MAX_MRGSORT_LIST = 4;  // hardware merge width
};
// Zero-fill the optional expert-token outputs. The work is sharded: only
// the first needInitExpertCore cores participate, each clearing its own
// currentCoreExpert-sized slice (set up in Init).
__aicore__ inline void MoeV2SortMultiCore::InitExpertTokensGlobalMemory() {
  if (this->blockIdx < this->needInitExpertCore) {
    if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
      InitGlobalMemory(expertTokensCountOrCumsumGm, currentCoreExpert, 0);
    }
    if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
      InitGlobalMemory(expertTokensBeforeCapacityGm, currentCoreExpert, 0);
    }
  }
}
// VBS stage-in for one loop: copy `size` expert ids from this core's GM
// slice into UB and synthesize the matching global row indices (an
// arithmetic progression starting at this core's absolute element offset).
__aicore__ inline void MoeV2SortMultiCore::VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
  int64_t inOffset = progress * sortCoreLoopElements;
  DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(size * sizeof(int32_t)), 0, 0, 0};
  DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
  DataCopyPad(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams);
  // Row indices live in the second half of the buffer, after sortNum keys.
  LocalTensor<int32_t> rowIdxLocal = inLocal[sortNum];
  int64_t startValue = this->blockIdx * this->vbsTilingData->perCoreElements + inOffset;
  SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
  ArithProgression<int32_t>(rowIdxLocal, startValue, 1, size);
  sortDataCopyInQueue.EnQue(inLocal);
}
// VBS in-UB sort for one loop. Expert ids are cast to fp32 and negated so
// the hardware Sort (which orders by descending key) yields ascending
// expert order; the tail beyond `size` is padded with MIN_FP32 so padding
// sinks to the end. Values carried alongside the keys are the row indices
// generated in VBSCopyIn.
__aicore__ inline void MoeV2SortMultiCore::UBSortCompute(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> expertForSourceRowLocal = inLocal[0];
  LocalTensor<float> expertForSourceRowLocalFp32;
  // In-place reinterpret: the int32 ids are overwritten by their fp32 form.
  expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast<float>();
  Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, sortNum);
  pipe_barrier(PIPE_V);
  Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, sortNum);
  pipe_barrier(PIPE_V);
  // Pad the partial last 32-element sort repeat with MIN_FP32 via a mask.
  int64_t duplicateNum = size % ONE_REPEAT_SORT_NUM;
  if (duplicateNum > 0) {
    int duplicateIndex = size - duplicateNum;
    uint64_t mask0 = UINT64_MAX;
    mask0 = mask0 << duplicateNum;
    mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
    uint64_t mask[2] = {mask0, 0};
    Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
    pipe_barrier(PIPE_V);
  }
  LocalTensor<float> concatLocal = expertForSourceRowLocalFp32;
  LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(sortNum));
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  LocalTensor<uint32_t> sourceRowLocal;
  sourceRowLocal = inLocal[sortNum].ReinterpretCast<uint32_t>();
  Sort<float, true>(outLocal, concatLocal, sourceRowLocal, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM);
  sortDataCopyOutQueue.EnQue<float>(outLocal);
  sortDataCopyInQueue.FreeTensor(inLocal);
}
// VBS stage-out: write this loop's sorted (key, index) records into this
// core's region of workspace 0, offset by the loop's position.
__aicore__ inline void MoeV2SortMultiCore::VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum) {
  LocalTensor<float> outLocal = sortDataCopyOutQueue.DeQue<float>();
  DataCopy(workspaceGms[0][this->blockIdx * GetSortLen<float>(this->vbsTilingData->perCoreElements) +
                           GetSortLen<float>(progress * sortCoreLoopElements)],
           outLocal, Align(GetSortLen<float>(size), sizeof(float)));
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
// Wire a MoeV2Mrgsort: `listNum` UB input slices over one shared input
// tensor, the current source workspace as GM input, and the other
// (ping-pong) workspace as GM output. NOTE(review): the UB tensors are
// freed immediately after registration -- the sorter appears to retain
// only their UB addresses; confirm against TQue semantics.
__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset,
                                                          int64_t loopOffset) {
  GlobalTensor<float> srcWsGm = workspaceGms[srcWsIndex][blockIdx * coreOffset + loopOffset];
  LocalTensor<float> inLocal = sortDataCopyInQueue.AllocTensor<float>();
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  for (int64_t i = 0; i < listNum; i++) {
    // Each list gets a disjoint slice sized for one max-elements loop.
    LocalTensor<float> inLocalT = inLocal[GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * i];
    sorter->SetInput(srcWsGm, inLocalT);
  }
  GlobalTensor<float> dstWsGm = workspaceGms[1 - srcWsIndex][blockIdx * coreOffset + loopOffset];
  sorter->SetOutput(dstWsGm, outLocal);
  sortDataCopyInQueue.FreeTensor(inLocal);
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
// Wire the final MoeV2MrgsortOut: UB input slices over the current source
// workspace, the sorted-expert-ids and dst-to-src GM tensors as outputs,
// and a scratch buffer for the merge. NOTE(review): as with InitMoeMrgSort,
// the UB tensors are freed right after registration.
__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum,
                                                             int64_t coreOffset) {
  GlobalTensor<float> srcWsGm = workspaceGms[srcWsIndex];
  LocalTensor<float> inLocal = sortDataCopyInQueue.AllocTensor<float>();
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  for (int64_t i = 0; i < listNum; i++) {
    LocalTensor<float> inLocalT = inLocal[GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * i];
    sorter->SetInput(srcWsGm, inLocalT);
  }
  // Second output view starts past the space reserved for the first.
  LocalTensor<float> outLocalV = outLocal[this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST];
  sorter->SetOutput(this->sortedexpertIdxGm, this->expandDstToSrcRowGm, outLocal, outLocalV);
  LocalTensor<float> tempBuffer =
      sortedBuffer.Get<float>(GetSortLen<float>(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST);
  sorter->SetBuffer(tempBuffer);
  sortDataCopyInQueue.FreeTensor(inLocal);
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
// Single-core multi-round merge: repeatedly 4-way-merge this core's sorted
// lists (ping-ponging between the two workspaces) until one list remains.
// Each round shrinks listNum by a factor of up to MAX_MRGSORT_LIST; the
// last group of a round may contain fewer lists and a shorter final list.
__aicore__ inline void MoeV2SortMultiCore::OneCoreVMSProcess(int64_t listNum, int64_t perListElements,
                                                             int64_t lastListElements) {
  int64_t coreOffset = GetSortLen<float>(this->vbsTilingData->perCoreElements);
  mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
  for (int64_t i = 0; listNum >= 1; i++) {
    int64_t loops = (listNum + MAX_MRGSORT_LIST - 1) / MAX_MRGSORT_LIST;
    int64_t remainListNum = listNum - (loops - 1) * MAX_MRGSORT_LIST;
    // Full groups: every list in the group has perListElements.
    mrgsortParam.perListElements = perListElements;
    mrgsortParam.lastListElements = perListElements;
    int64_t loopOffset = GetSortLen<float>(mrgsortParam.perListElements * MAX_MRGSORT_LIST);
    for (int64_t loop = 0; loop < loops - 1; loop++) {
      InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, loop * loopOffset);
      mrgsorter.Init(&mrgsortParam);
      mrgsorter.Process();
    }
    // Final (possibly partial) group: its last list may be shorter.
    mrgsortParam.perListElements = perListElements;
    mrgsortParam.lastListElements = lastListElements;
    InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, (loops - 1) * loopOffset);
    mrgsorter.Init(&mrgsortParam);
    mrgsorter.Process();
    // Next round: merged groups become the new lists; swap workspaces.
    listNum = loops;
    lastListElements = perListElements * (remainListNum - 1) + lastListElements;
    perListElements = perListElements * MAX_MRGSORT_LIST;
    srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM;
    if (loops == 1) {
      break;
    }
  }
}
// One VBS loop: stage in, sort in UB, stage the sorted records out.
__aicore__ inline void MoeV2SortMultiCore::UBSortProcess(int64_t progress, int64_t size, int64_t sortNum) {
  VBSCopyIn(progress, size, sortNum);
  UBSortCompute(progress, size, sortNum);
  VBSCopyOut(progress, size, sortNum);
}
// VBS stage: each participating core sorts its slice loop by loop (the
// last loop may be shorter), then merges its own loops into one sorted
// list. All cores rendezvous at the end, including non-participants.
__aicore__ inline void MoeV2SortMultiCore::VBSProcess() {
  if (this->blockIdx < this->vbsTilingData->needCoreNum) {
    // Round the loop size up to whole 32-element sort repeats.
    int64_t sortNum = Ceil(sortCoreLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    for (int64_t loop = 0; loop < sortCoreLoops - 1; loop++) {
      UBSortProcess(loop, sortCoreLoopElements, sortNum);
    }
    sortNum = Ceil(sortCoreLastLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
    UBSortProcess(sortCoreLoops - 1, sortCoreLastLoopElements, sortNum);
    if (sortCoreLoops > 1) {
      OneCoreVMSProcess(sortCoreLoops, sortCoreLoopElements, sortCoreLastLoopElements);
    }
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
// VMS stage: cross-core merge rounds. Each round assigns one 4-list group
// per core; the last active core takes the partial group whose final list
// may be shorter. Rounds repeat (with a full-device barrier and workspace
// swap between them) until at most MAX_MRGSORT_LIST lists remain for the
// final SortOutProcess.
__aicore__ inline void MoeV2SortMultiCore::VMSProcess() {
  int64_t currentStageNeedCoreNum = this->vmsTilingData->needCoreNum;
  perListElements = this->vbsTilingData->perCoreElements;
  lastListElements = this->vbsTilingData->lastCoreElements;
  listNum = this->vbsTilingData->needCoreNum;
  for (; listNum > MAX_MRGSORT_LIST;) {
    currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST);
    int64_t coreOffset = GetSortLen<float>(perListElements * MAX_MRGSORT_LIST);
    int64_t remainListNum = listNum - (currentStageNeedCoreNum - 1) * MAX_MRGSORT_LIST;
    if (this->blockIdx < currentStageNeedCoreNum - 1) {
      // Full group: all four lists have perListElements.
      mrgsortParam.perListElements = perListElements;
      mrgsortParam.lastListElements = perListElements;
      mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
      InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, 0);
      mrgsorter.Init(&mrgsortParam);
      mrgsorter.Process();
    } else if (this->blockIdx == currentStageNeedCoreNum - 1) {
      // Last group: may have fewer lists and a shorter final list.
      mrgsortParam.perListElements = perListElements;
      mrgsortParam.lastListElements = lastListElements;
      mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
      InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, 0);
      mrgsorter.Init(&mrgsortParam);
      mrgsorter.Process();
    }
    // Merged groups become next round's lists; swap ping-pong workspaces.
    listNum = currentStageNeedCoreNum;
    currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST);
    srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM;
    lastListElements = perListElements * (remainListNum - 1) + lastListElements;
    perListElements = perListElements * MAX_MRGSORT_LIST;
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
  }
}
// Final stage: core 0 merges the remaining (<= 4) lists and writes the
// sorted expert ids plus the dst->src row mapping; all cores then sync.
__aicore__ inline void MoeV2SortMultiCore::SortOutProcess() {
  if (this->blockIdx < 1) {
    mrgsortParam.perListElements = perListElements;
    mrgsortParam.lastListElements = lastListElements;
    mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements;
    MoeV2MrgsortOut sorter;
    InitMoeMrgSortOut(&sorter, listNum, GetSortLen<float>(perListElements));
    sorter.Init(&mrgsortParam, pipe);
    sorter.Process();
  }
#ifndef __CCE_KT_TEST__
  AscendC::SyncAll();
#endif
}
// Bind tiling data, compute this core's identity and slice, map the GM
// inputs/outputs, lay out the workspace, and size the UB queues.
//
// Workspace layout (int32/float units, all Align()-padded):
//   [0, A)            sorted expert ids
//   [A, 2A)           dst->src row mapping
//   [2A, 2A + 2A)     ping workspace (key+value records)
//   [4A, 4A + 2A)     pong workspace
// where A = Align(n * k, sizeof(int32_t)).
template <typename TilingData>
__aicore__ inline void MoeV2SortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum,
                                                GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace,
                                                const TilingData* tilingData, TPipe* tPipe) {
  this->totalLength = tilingData->n * tilingData->k;
  this->coreNum = tilingData->coreNum;
  this->vbsTilingData = &(tilingData->vbsComputeParamsOp);
  this->vmsTilingData = &(tilingData->vmsMiddleComputeParamsOp);
  this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp);
  // Flatten block and sub-block ids into one linear core index.
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->tileLength = this->vbsTilingData->perCorePerLoopElements;
  this->sortTotalLength = this->vbsTilingData->perCoreElements;
  if (this->blockIdx == tilingData->vbsComputeParamsOp.needCoreNum - 1) {
    // Last VBS core may get a shorter slice.
    this->tileLength = this->vbsTilingData->lastCorePerLoopElements;
    this->sortTotalLength = this->vbsTilingData->lastCoreElements;
  }
  this->n = tilingData->n;
  this->k = tilingData->k;
  this->expertNum = tilingData->expertNum;
  this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
  this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;
  // VBS param init
  if (this->blockIdx == this->vbsTilingData->needCoreNum - 1) {
    sortCoreLoops = this->vbsTilingData->lastCoreLoops;
    sortCoreLoopElements = this->vbsTilingData->lastCorePerLoopElements;
    sortCoreLastLoopElements = this->vbsTilingData->lastCoreLastLoopElements;
  } else {
    sortCoreLoops = this->vbsTilingData->perCoreLoops;
    sortCoreLoopElements = this->vbsTilingData->perCorePerLoopElements;
    sortCoreLastLoopElements = this->vbsTilingData->perCoreLastLoopElements;
  }
  this->pipe = tPipe;
  // This core reads only its own slice of the expert-index input.
  expertIdxGm.SetGlobalBuffer(
      (__gm__ int32_t*)expertIdx + this->blockIdx * tilingData->vbsComputeParamsOp.perCoreElements,
      this->sortTotalLength);
  sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace),
                                    Align(this->totalLength, sizeof(int32_t)));
  expandDstToSrcRowGm.SetGlobalBuffer(
      reinterpret_cast<__gm__ int32_t*>(workspace) + Align(this->totalLength, sizeof(int32_t)),
      Align(this->totalLength, sizeof(int32_t)));
  // Shard the expert-token zero-initialization across cores.
  this->perCoreExpert = Align((this->expertNum + this->coreNum - 1) / this->coreNum, sizeof(int32_t));
  this->needInitExpertCore = (this->expertNum + this->perCoreExpert - 1) / this->perCoreExpert;
  this->currentCoreExpert = this->perCoreExpert;
  if (this->blockIdx == needInitExpertCore - 1) {
    this->currentCoreExpert = this->expertNum - (this->needInitExpertCore - 1) * this->perCoreExpert;
  }
  if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) {
    expertTokensCountOrCumsumGm.SetGlobalBuffer(
        (__gm__ int32_t*)expertTokensCountOrCumsum + this->blockIdx * this->perCoreExpert, this->currentCoreExpert);
  }
  if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) {
    expertTokensBeforeCapacityGm.SetGlobalBuffer(
        (__gm__ int32_t*)expertTokensBeforeCapacity + this->blockIdx * this->perCoreExpert, this->currentCoreExpert);
  }
  // key and value
  int64_t kvFactor = 2;
  workspaceGms[0].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2,
                                  Align(this->totalLength, sizeof(int32_t)) * kvFactor);
  workspaceGms[1].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * (kvFactor + 2),
                                  Align(this->totalLength, sizeof(int32_t)) * kvFactor);
  // UB buffers must fit both a VBS loop and a 4-list merge-out loop.
  int64_t bufferSize = Ceil(Max(this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST, sortCoreLoopElements),
                            ONE_REPEAT_SORT_NUM) *
                       ONE_REPEAT_SORT_NUM * sizeof(int32_t) * kvFactor;
  pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize);
  pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize);
  pipe->InitBuffer(sortedBuffer, bufferSize);
}
// Full pipeline: zero the optional expert-token outputs, per-core sort
// (VBS), cross-core merge rounds (VMS), then the final single-core
// merge-out that materializes both result tensors.
__aicore__ inline void MoeV2SortMultiCore::Process() {
  InitExpertTokensGlobalMemory();
  VBSProcess();
  VMSProcess();
  SortOutProcess();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_VBS_ONE_CORE_H

View File

@@ -0,0 +1,162 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_sort_one_core.h
* \brief
*/
#ifndef INNER_MOE_V2_SORT_ONE_CORE_H
#define INNER_MOE_V2_SORT_ONE_CORE_H
#include "moe_v2_mrgsort.h"
#include "moe_v2_sort_base.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Single-core fast path: when the whole problem fits in one UB sort, core 0
// copies all expert ids in, sorts them (carrying row indices as values),
// and writes the sorted ids plus the dst->src mapping to workspace.
class MoeV2SortOneCore : public MoeV2SortBase {
 public:
  __aicore__ inline MoeV2SortOneCore(){};
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity,
                              GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
 private:
  __aicore__ inline void CopyIn();
  __aicore__ inline void SortCompute();
  __aicore__ inline void CopyOut();
 private:
  int64_t sortNum;   // element count rounded up to whole sort repeats
  int64_t blockIdx;  // linear core index (block + sub-block)
};
// Copy all expert ids into UB and synthesize row indices 0..sortNum-1 in
// the second half of the buffer (after the sortNum key slots).
__aicore__ inline void MoeV2SortOneCore::CopyIn() {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.AllocTensor<int32_t>();
  DataCopyExtParams dataCopyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(this->totalLength * sizeof(int32_t)),
                                   0, 0, 0};
  DataCopyPadExtParams<int32_t> dataCopyPadParams{false, 0, 0, 0};
  DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams);
  LocalTensor<int32_t> rowIdxLocal = inLocal[this->sortNum];
  ArithProgression<int32_t>(rowIdxLocal, 0, 1, this->sortNum);
  sortDataCopyInQueue.EnQue(inLocal);
}
// Sort all expert ids in UB. Ids are cast to fp32 and negated so the
// hardware Sort (descending by key) yields ascending expert order; the
// tail beyond totalLength is padded with MIN_FP32 so padding sinks to the
// end. After sorting, keys and row indices are split out, keys un-negated
// and round-cast back to int32.
__aicore__ inline void MoeV2SortOneCore::SortCompute() {
  LocalTensor<int32_t> inLocal = sortDataCopyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> expertForSourceRowLocal = inLocal[0];
  // In-place reinterpret: int32 ids are overwritten by their fp32 form.
  LocalTensor<float> expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast<float>();
  Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength);
  pipe_barrier(PIPE_V);
  Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, this->tileLength);
  pipe_barrier(PIPE_V);
  // Pad the partial last 32-element sort repeat with MIN_FP32 via a mask.
  int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM;
  if (duplicateNum > 0) {
    int duplicateIndex = this->totalLength - duplicateNum;
    uint64_t mask0 = UINT64_MAX;
    mask0 = mask0 << duplicateNum;
    mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM);
    uint64_t mask[2] = {mask0, 0};
    Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE);
    pipe_barrier(PIPE_V);
  }
  LocalTensor<float> concatLocal;
  LocalTensor<float> tempTensor = tempBuffer.Get<float>(GetSortLen<float>(this->sortNum));
  Concat(concatLocal, expertForSourceRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  LocalTensor<float> sortedLocal = sortedBuffer.Get<float>(GetSortLen<float>(this->sortNum));
  LocalTensor<uint32_t> sourceRowLocal;
  sourceRowLocal = inLocal[this->sortNum].ReinterpretCast<uint32_t>();
  Sort<float, true>(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  LocalTensor<float> outLocal = sortDataCopyOutQueue.AllocTensor<float>();
  LocalTensor<float> sortedExpertForSourceRowLocal = outLocal[0];
  LocalTensor<uint32_t> expandDstToSrcRowLocal;
  expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast<uint32_t>();
  // Split sorted records into keys (first half) and indices (second half).
  Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM);
  pipe_barrier(PIPE_V);
  // Undo the negation and convert the keys back to int32 in-place.
  Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength);
  pipe_barrier(PIPE_V);
  LocalTensor<int32_t> expertForSourceRowLocalInt32;
  expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast<int32_t>();
  Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength);
  sortDataCopyOutQueue.EnQue<float>(outLocal);
  sortDataCopyInQueue.FreeTensor(inLocal);
}
// Write the sorted expert ids and the dst->src row mapping (stored in the
// first and second halves of the output buffer) to their workspace GMs.
__aicore__ inline void MoeV2SortOneCore::CopyOut() {
  LocalTensor<int32_t> outLocal = sortDataCopyOutQueue.DeQue<int32_t>();
  DataCopyParams intriParams;
  intriParams.blockCount = 1;
  intriParams.blockLen = this->totalLength * sizeof(int32_t);
  DataCopyPad(sortedexpertIdxGm, outLocal[0], intriParams);
  DataCopyPad(expandDstToSrcRowGm, outLocal[this->sortNum], intriParams);
  sortDataCopyOutQueue.FreeTensor(outLocal);
}
// Bind tiling data, map the GM input and workspace outputs, optionally
// zero the expert-token outputs (done by the last core), and size the
// UB queues for one whole-problem sort.
template <typename TilingData>
__aicore__ inline void MoeV2SortOneCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum,
                                              GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace,
                                              const TilingData* tilingData, TPipe* tPipe) {
  // Flatten block and sub-block ids into one linear core index.
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t));
  // Round up to whole 32-element sort repeats.
  this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
  this->totalLength = tilingData->n * tilingData->k;
  this->coreNum = tilingData->coreNum;
  this->pipe = tPipe;
  this->n = tilingData->n;
  this->k = tilingData->k;
  this->expertNum = tilingData->expertNum;
  this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag;
  this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag;
  expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength);
  // Workspace layout: sorted expert ids, then the dst->src row mapping.
  sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace), this->tileLength);
  expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace) + this->tileLength,
                                      this->tileLength);
  // The last core zero-fills the optional expert-token outputs while
  // core 0 performs the sort (see Process).
  if (this->blockIdx == this->coreNum - 1) {
    if (this->expertTokensCountOrCumsumFlag > 0) {
      expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum,
                                                  Align(this->expertNum, sizeof(int32_t)));
      InitGlobalMemory(expertTokensCountOrCumsumGm, this->expertNum, 0);
    }
    if (this->expertTokensBeforeCapacityFlag == 1) {
      expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity,
                                                   Align(this->expertNum, sizeof(int32_t)));
      InitGlobalMemory(expertTokensBeforeCapacityGm, this->expertNum, 0);
    }
  }
  // key and value
  int64_t kvFactor = 2;
  int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor;
  pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize);
  pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize);
  pipe->InitBuffer(tempBuffer, buffSize);
  pipe->InitBuffer(sortedBuffer, buffSize);
}
// Only the very first (sub)block runs the single-core sort pipeline; every
// other block skips straight to the cross-core barrier.
__aicore__ inline void MoeV2SortOneCore::Process() {
  int64_t globalIdx = get_block_idx() + get_subblockid() * get_block_num();
  if (globalIdx < 1) {
    CopyIn();
    SortCompute();
    CopyOut();
  }
  this->SyncAll();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_SORT_ONE_CORE_H

View File

@@ -0,0 +1,560 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_src_to_dst_and_gather.h
* \brief
*/
#ifndef MOE_V2_SRC_TO_DST_AND_GATHER_H
#define MOE_V2_SRC_TO_DST_AND_GATHER_H
#include "moe_v2_common.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Stage of MoeInitRoutingQuantV2 that scatters sorted tokens into per-expert
// capacity slots while gathering, smooth-scaling and dynamically quantizing
// each routed row to int8. Capacity slots that receive no token are
// zero-padded. T is the input activation dtype; TilingData is the host-side
// tiling structure.
template <typename T, typename TilingData>
class MoeV2SrcToDstAndGather {
public:
  __aicore__ inline MoeV2SrcToDstAndGather(){};
  __aicore__ inline void Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                              GM_ADDR dynamicQuantScale, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
private:
  // Loads one loop of (dst->src row ids, sorted expert ids) into UB.
  __aicore__ inline void CopyIn(int64_t progress);
  // Row scatter + quantize when the whole row fits one column tile.
  __aicore__ inline void CopyOut(int64_t progress);
  // Row scatter + quantize when a row spans several column tiles.
  __aicore__ inline void CopyOutLoops(int64_t progress);
  // Quantizes one full row (single-tile path), writes row + scale out.
  __aicore__ inline void Compute(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);
  // Pass 1 of the tiled path: returns max(|x*smooth|) of tile j and caches
  // the smoothed fp32 tile in workspace (quantSrcGm).
  __aicore__ inline float ComputeMax(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal,
                                     LocalTensor<float>& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx,
                                     int64_t j);
  // Pass 2 of the tiled path: divides tile j by the row scale, casts to int8
  // and stores it into the expanded output.
  __aicore__ inline void ComputeScale(LocalTensor<float>& inLocal, LocalTensor<float>& tempLocal, float scaleTemp,
                                      int64_t dstIndex, int64_t j);
  // Full two-pass quantization of one row spanning colLoops tiles.
  __aicore__ inline void ComputeLoops(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);
  // Zero-fills capacity slots of trailing experts (last active core only).
  __aicore__ inline void CopyOutRemain();
  __aicore__ inline void SyncAll();
  // Prepares the zero padding tensors and the cross-core expert carry-over
  // state (lastCoreExpertId / lastCoreExpertIdNum).
  __aicore__ inline void AssistInit();
private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;         // row-id / expert-id input
  TQue<QuePosition::VECOUT, 1> copyOutQueue;       // expandedRowIdx output
  TQue<QuePosition::VECOUT, 1> copyOutZeroQueue;   // zero row used for padding
  TQue<QuePosition::VECIN, 1> inputXInQueue;       // input activations
  TQue<QuePosition::VECIN, 1> smoothInQueue;       // smooth quant scales
  TQue<QuePosition::VECOUT, 1> calcQueue;          // fp32 scratch
  TQue<QuePosition::VECOUT, 1> inputXOutQueue;     // int8 quantized rows
  TQue<QuePosition::VECOUT, 1> scaleOutQueue;      // per-row dynamic scale
  TQue<QuePosition::VECOUT, 1> scaleOutZeroQueue;  // zero scale for padding
  GlobalTensor<int32_t> expandDstToSrcRowGm;
  GlobalTensor<int32_t> expandedRowIdxGm;
  GlobalTensor<int32_t> expertIdxValueGm;    // per-core (expertId, count) pairs
  GlobalTensor<int32_t> expandedExpertIdxGm;
  GlobalTensor<int8_t> expandedXGm;
  GlobalTensor<T> inputXGm;
  GlobalTensor<float> quantSmoothGm;
  GlobalTensor<float> dynamicQuantScaleGm;
  GlobalTensor<float> quantSrcGm;       // workspace cache of smoothed fp32 rows
  LocalTensor<int8_t> outTmpLocal;      // persistent zero row for padding
  LocalTensor<float> scaleOutTmpLocal;  // persistent zero scale for padding
  LocalTensor<float> smoothLocal;
  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;
  int64_t coreNum;
  int64_t blockIdx;
  int64_t totalLength;  // n * k expanded rows overall
  int64_t currentLoopRows;
  int64_t coreRows;
  int64_t perLoopRows;
  int64_t lastLoopRows;
  int64_t rowLoops;
  int64_t expertCapacity;
  int64_t expertNum;
  int64_t cols;
  int64_t perLoopCols;
  int64_t lastLoopCols;
  int64_t colLoops;
  int64_t perLoopColsAlign;
  int64_t k;
  int64_t colsTileLength;
  int64_t smoothType;  // 0: no smoothing, 1: one shared vector, 2: per-expert
  int64_t tokenCount = 0;
  int32_t lastExpertId = -1;
  int32_t lastCoreExpertId = 0;
  int32_t lastCoreExpertIdNum = 0;
};
// Prepare the padding tensors (a zeroed row and a zeroed scale block) and
// recover, from the per-core (expertId, count) pairs written by earlier
// cores, how many slots of our starting expert were already consumed.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::AssistInit() {
  LocalTensor<int16_t> zeroRow = copyOutZeroQueue.AllocTensor<int16_t>();
  Duplicate<int16_t>(zeroRow, static_cast<int16_t>(0), this->perLoopCols);
  copyOutZeroQueue.EnQue<int16_t>(zeroRow);
  LocalTensor<float> zeroScale = scaleOutZeroQueue.AllocTensor<float>();
  Duplicate<float>(zeroScale, 0.0f, 8);
  scaleOutZeroQueue.EnQue<float>(zeroScale);
  if (this->blockIdx == 0) {
    return;
  }
  // The immediately preceding core defines the expert we resume in ...
  this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2);
  this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1);
  // ... and every earlier core that ended on the same expert contributes too.
  for (int64_t core = this->blockIdx - 2; core >= 0; core--) {
    if (expertIdxValueGm.GetValue(core * 2) < this->lastCoreExpertId) {
      break;
    }
    this->lastCoreExpertIdNum += expertIdxValueGm.GetValue(core * 2 + 1);
  }
}
// Load the current loop's dst->src row indices and sorted expert ids into a
// single UB tensor: [0, len) holds the row map, [len, 2*len) the expert ids.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyIn(int64_t progress) {
  LocalTensor<int32_t> rowLocal = copyInQueue.AllocTensor<int32_t>();
  int64_t alignedRows = Align(currentLoopRows, sizeof(int32_t));
  DataCopy(rowLocal, expandDstToSrcRowGm[progress * perLoopRows], alignedRows);
  DataCopy(rowLocal[alignedRows], expandedExpertIdxGm[progress * perLoopRows], alignedRows);
  copyInQueue.EnQue<int32_t>(rowLocal);
}
// Gather one source row, apply the optional smooth scale, dynamically
// quantize it to int8 (scale = max(|row|) / 127) and write both the int8 row
// (at dstIdx) and its scale. Single-column-tile path: the whole row fits UB.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Compute(int32_t srcIdx, int32_t dstIdx,
                                                                      int32_t expertIdx) {
  DataCopyExtParams copyInParams{1, static_cast<uint32_t>(this->cols * sizeof(T)), 0, 0, 0};
  DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
  DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
  LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
  // srcIdx counts expanded (token, k) pairs, so the source token row is
  // srcIdx / k. Non-float inputs are staged in the upper half of the buffer
  // and widened to fp32 in place below.
  if constexpr (IsSameType<T, float>::value) {
    DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
  } else {
    DataCopyPad(inLocal.template ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols],
                copyInParams, {false, 0, 0, 0});
  }
  // smoothType == 2 uses a per-expert smooth vector; type 1 was preloaded
  // once in Process(); type 0 skips smoothing entirely.
  if (smoothType == 2) {
    DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0});
  }
  inputXInQueue.EnQue<float>(inLocal);
  smoothInQueue.EnQue(smoothLocal);
  smoothLocal = smoothInQueue.DeQue<float>();
  inLocal = inputXInQueue.DeQue<float>();
  LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
  LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
  LocalTensor<float> dynamicQuantLocal = scaleOutQueue.AllocTensor<float>();
  if constexpr (!IsSameType<T, float>::value) {
    Cast(inLocal, inLocal.template ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
    pipe_barrier(PIPE_V);
  }
  if (smoothType != 0) {
    Mul(inLocal, inLocal, smoothLocal, this->cols);
    pipe_barrier(PIPE_V);
  }
  // Dynamic quant scale: max(|x|) / 127, broadcast into a full 8-float block.
  Abs(tempLocal, inLocal, this->cols);
  pipe_barrier(PIPE_V);
  ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols);
  pipe_barrier(PIPE_V);
  float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f;
  Duplicate<float>(dynamicQuantLocal, maxValue, 8);
  Duplicate<float>(tempLocal, maxValue, this->cols);
  pipe_barrier(PIPE_V);
  Div(tempLocal, inLocal, tempLocal, this->cols);
  pipe_barrier(PIPE_V);
  // int8 conversion goes through half: fp32->half (truncate), half->int8 (round).
  Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, this->cols);
  pipe_barrier(PIPE_V);
  Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, this->cols);
  calcQueue.FreeTensor(tempLocal);
  inputXOutQueue.EnQue(outLocal);
  scaleOutQueue.EnQue(dynamicQuantLocal);
  LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
  DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0});
  outLocal = inputXOutQueue.DeQue<int8_t>();
#ifndef __CCE_KT_TEST__
  DataCopyPad(expandedXGm[dstIdx * this->cols], outLocal, copyOutParams);
#endif
  inputXInQueue.FreeTensor(inLocal);
  inputXOutQueue.FreeTensor(outLocal);
  scaleOutQueue.FreeTensor(quantScaleLocal);
}
// Single-tile scatter loop over the rows loaded by CopyIn(progress): for
// each token, first zero-pad the remaining capacity slots of every expert
// between the previous expert and this one, then, if this expert still has
// free capacity, record its destination slot in expandedRowIdx and quantize
// the row via Compute(). Tokens beyond an expert's capacity are dropped.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOut(int64_t progress) {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  int64_t length = Align(currentLoopRows, sizeof(int32_t));
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};
  DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(this->cols * sizeof(int8_t)), 0, 0, 0};
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  // First loop on this core: resume from the expert/count carried over from
  // the preceding cores (computed in AssistInit).
  if (this->lastExpertId == -1) {
    this->lastExpertId = this->lastCoreExpertId;
    this->tokenCount = this->lastCoreExpertIdNum;
  }
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    int32_t expertIdx = inLocal[length].GetValue(idx);
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    int32_t index = 0;
    // Zero-fill the unfilled capacity slots of all experts we skipped over.
    while (this->lastExpertId < expertIdx) {
      while (this->tokenCount < this->expertCapacity) {
        index = this->lastExpertId * this->expertCapacity + this->tokenCount;
        DataCopyPad(expandedXGm[index * this->cols], this->outTmpLocal, copyParams1);
        DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
        this->tokenCount++;
      }
      this->tokenCount = 0;
      this->lastExpertId++;
    }
    if (this->tokenCount < this->expertCapacity) {
      int32_t outOffset = inLocal.GetValue(idx);
      index = expertIdx * this->expertCapacity + this->tokenCount;
      outLocal.SetValue(0, index);
      SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
      DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
      Compute(outOffset, index, expertIdx);
      SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      this->tokenCount++;
    }
  }
  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
// Pass 1 of the tiled quantization for column tile j of source row srcIdx:
// load the tile, widen to fp32 if needed, apply the smooth scale when
// enabled, cache the smoothed fp32 tile in workspace (quantSrcGm) for the
// second pass, and return the tile's max(|value|).
template <typename T, typename TilingData>
__aicore__ inline float MoeV2SrcToDstAndGather<T, TilingData>::ComputeMax(LocalTensor<float>& inLocal,
                                                                          LocalTensor<float>& tempLocal,
                                                                          LocalTensor<float>& dynamicQuantLocal,
                                                                          int32_t srcIdx, int32_t expertIdx,
                                                                          int64_t j) {
  LocalTensor<float> smoothLocal = smoothInQueue.AllocTensor<float>();
  DataCopyExtParams intriParamsT{1, static_cast<uint32_t>(colsTileLength * sizeof(T)), 0, 0, 0};
  DataCopyExtParams intriParamsFp32{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
  // Non-float input is staged at offset perLoopColsAlign, then cast in place.
  if constexpr (!IsSameType<T, float>::value) {
    DataCopyPad(inLocal.ReinterpretCast<T>()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols],
                intriParamsT, {false, 0, 0, 0});
  } else {
    DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
  }
  inputXInQueue.EnQue<float>(inLocal);
  inLocal = inputXInQueue.DeQue<float>();
  if constexpr (!IsSameType<T, float>::value) {
    Cast(inLocal, inLocal.ReinterpretCast<T>()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength);
    pipe_barrier(PIPE_V);
  }
  if (smoothType != 0) {
    DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32,
                {false, 0, 0, 0});
    smoothInQueue.EnQue(smoothLocal);
    smoothLocal = smoothInQueue.DeQue<float>();
    Mul(inLocal, inLocal, smoothLocal, colsTileLength);
    pipe_barrier(PIPE_V);
  }
  Abs(tempLocal, inLocal, colsTileLength);
  pipe_barrier(PIPE_V);
  // The tile maximum is written to slot 8 so slots 0..7 stay free for the
  // final broadcast row scale.
  ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength);
  DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32);
  smoothInQueue.FreeTensor(smoothLocal);
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
  // NOTE(review): no explicit V->S wait is visible before this scalar read of
  // the ReduceMax result — presumably covered by surrounding sync; confirm.
  return dynamicQuantLocal.GetValue(8);
}
// Pass 2 of the tiled quantization: reload smoothed fp32 tile j from
// workspace, divide by the row scale `scaleTemp`, cast to int8 (through
// half) and store it into output row dstIndex.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::ComputeScale(LocalTensor<float>& inLocal,
                                                                           LocalTensor<float>& tempLocal,
                                                                           float scaleTemp, int64_t dstIndex,
                                                                           int64_t j) {
  DataCopyExtParams copyInParams{1, static_cast<uint32_t>(colsTileLength * sizeof(float)), 0, 0, 0};
  DataCopyExtParams copyOutParams{1, static_cast<uint32_t>(colsTileLength * sizeof(int8_t)), 0, 0, 0};
  LocalTensor<int8_t> outLocal = inputXOutQueue.AllocTensor<int8_t>();
  DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0});
  inputXInQueue.EnQue<float>(inLocal);
  inLocal = inputXInQueue.DeQue<float>();
  Duplicate<float>(tempLocal, scaleTemp, colsTileLength);
  pipe_barrier(PIPE_V);
  Div(tempLocal, inLocal, tempLocal, colsTileLength);
  pipe_barrier(PIPE_V);
  // fp32 -> half (truncate), then half -> int8 (round).
  Cast(tempLocal.ReinterpretCast<half>(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength);
  pipe_barrier(PIPE_V);
  Cast(outLocal, tempLocal.ReinterpretCast<half>(), RoundMode::CAST_ROUND, colsTileLength);
  inputXOutQueue.EnQue(outLocal);
  outLocal = inputXOutQueue.DeQue<int8_t>();
  DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams);
  inputXOutQueue.FreeTensor(outLocal);
  SetWaitFlag<HardEvent::MTE3_MTE2>(HardEvent::MTE3_MTE2);
}
// Two-pass dynamic quantization of one row spanning colLoops column tiles:
// pass 1 folds per-tile maxima into the row maximum, pass 2 rescales each
// tile by max/127 and emits int8. The per-row scale is stored at dstIdx.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::ComputeLoops(int32_t srcIdx, int32_t dstIdx,
                                                                           int32_t expertIdx) {
  LocalTensor<float> inLocal = inputXInQueue.AllocTensor<float>();
  LocalTensor<float> tempLocal = calcQueue.AllocTensor<float>();
  LocalTensor<float> quantScaleLocal = scaleOutQueue.AllocTensor<float>();
  // 0xFF7FFFFF is the bit pattern of -FLT_MAX: seed for the running maximum.
  uint32_t tmp = 0xFF7FFFFF;
  float reduceMax = *((float*)&tmp);
  for (int64_t j = 0; j < this->colLoops; j++) {
    colsTileLength = this->perLoopCols;
    if (j == this->colLoops - 1) {
      colsTileLength = this->lastLoopCols;
    }
    float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j);
    reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
  }
  float scaleTemp = reduceMax / 127.0f;
  Duplicate<float>(quantScaleLocal, scaleTemp, 8);
  scaleOutQueue.EnQue(quantScaleLocal);
  quantScaleLocal = scaleOutQueue.DeQue<float>();
  DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0});
  // Second sweep over the same tiles to quantize with the final row scale.
  for (int64_t j = 0; j < this->colLoops; j++) {
    colsTileLength = this->perLoopCols;
    if (j == this->colLoops - 1) {
      colsTileLength = this->lastLoopCols;
    }
    ComputeScale(inLocal, tempLocal, scaleTemp, dstIdx, j);
  }
  inputXInQueue.FreeTensor(inLocal);
  calcQueue.FreeTensor(tempLocal);
  scaleOutQueue.FreeTensor(quantScaleLocal);
}
// Tiled counterpart of CopyOut(): pads skipped experts tile by tile and
// quantizes surviving tokens via the two-pass ComputeLoops(). The expert id
// is forwarded only for per-expert smoothing (smoothType == 2).
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOutLoops(int64_t progress) {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  int64_t length = Align(currentLoopRows, sizeof(int32_t));
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  // First loop on this core: resume from the carry-over state of AssistInit.
  if (this->lastExpertId == -1) {
    this->lastExpertId = this->lastCoreExpertId;
    this->tokenCount = this->lastCoreExpertIdNum;
  }
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    int32_t expertIdx = inLocal[length].GetValue(idx);
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    int32_t index = 0;
    // Zero-fill the unfilled capacity slots of all experts we skipped over,
    // one column tile at a time.
    while (this->lastExpertId < expertIdx) {
      while (this->tokenCount < this->expertCapacity) {
        index = this->lastExpertId * this->expertCapacity + this->tokenCount;
        int64_t col = this->perLoopCols;
        DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
        for (int64_t i = 0; i < this->colLoops; i++) {
          if (i == this->colLoops - 1) {
            col = this->lastLoopCols;
          }
          DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(int8_t)), 0, 0, 0};
          DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1);
          SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
        }
        this->tokenCount++;
      }
      this->tokenCount = 0;
      this->lastExpertId++;
    }
    if (this->tokenCount < this->expertCapacity) {
      int32_t outOffset = inLocal.GetValue(idx);
      index = expertIdx * this->expertCapacity + this->tokenCount;
      outLocal.SetValue(0, index);
      SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
      DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
      if (smoothType == 2) {
        ComputeLoops(outOffset, index, expertIdx);
      } else {
        ComputeLoops(outOffset, index, 0);
      }
      SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      this->tokenCount++;
    }
  }
  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
// After the last token, the last active core zero-fills the output rows and
// scales of all remaining experts; every other core only releases the
// persistent padding tensors.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::CopyOutRemain() {
  if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
    scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
    return;
  }
  while (this->lastExpertId < this->expertNum) {
    while (this->tokenCount < this->expertCapacity) {
      int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
      int64_t col = this->perLoopCols;
      DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0});
      for (int64_t i = 0; i < this->colLoops; i++) {
        if (i == this->colLoops - 1) {
          col = this->lastLoopCols;
        }
        DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(int8_t)), 0, 0, 0};
        DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
        SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      }
      this->tokenCount++;
    }
    this->tokenCount = 0;
    this->lastExpertId++;
  }
  copyOutZeroQueue.FreeTensor(this->outTmpLocal);
  scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
}
// Bind global/workspace buffers, cache tiling scalars and size the UB queues
// for the scatter + gather + quantize stage.
// Fix: removed the unused local `int64_t blockNum = GetBlockNum();` (dead code
// — the value was never read).
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx,
                                                                   GM_ADDR expandedX, GM_ADDR dynamicQuantScale,
                                                                   GM_ADDR workspace, const TilingData* tilingData,
                                                                   TPipe* tPipe) {
  this->pipe = tPipe;
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->coreNum = tilingData->coreNum;
  this->totalLength = tilingData->n * tilingData->k;
  this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp);
  this->expertNum = tilingData->expertNum;
  this->expertCapacity = tilingData->expertCapacity;
  this->cols = tilingData->cols;
  this->k = tilingData->k;
  this->smoothType = tilingData->smoothType;
  // The last active core may carry a smaller row workload than the others.
  if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
    this->coreRows = this->srcToDstTilingData->lastCoreRows;
    this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
    this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
    this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
  } else {
    this->coreRows = this->srcToDstTilingData->perCoreRows;
    this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
    this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
    this->rowLoops = this->srcToDstTilingData->perCoreLoops;
  }
  this->perLoopCols = this->srcToDstTilingData->perLoopCols;
  this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
  this->colLoops = this->srcToDstTilingData->colLoops;
  this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T));
  inputXGm.SetGlobalBuffer((__gm__ T*)x);
  quantSmoothGm.SetGlobalBuffer((__gm__ float*)scale);
  dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale);
  int64_t length = Align(this->totalLength, sizeof(int32_t));
  expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length);
  expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX, this->expertNum * this->expertCapacity * this->cols);
  // Workspace layout: [sorted expert ids | dst->src map | per-core
  // (expertId, count) pairs | optional fp32 row cache for the tiled path].
  expandedExpertIdxGm.SetGlobalBuffer(
      (__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
      Align(this->coreRows, sizeof(int32_t)));
  expandDstToSrcRowGm.SetGlobalBuffer(
      (__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows,
      Align(this->coreRows, sizeof(int32_t)));
  expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2);
  if (this->colLoops > 1) {
    quantSrcGm.SetGlobalBuffer((__gm__ float*)workspace + length * 2 + this->coreNum * 2 + this->blockIdx * this->cols,
                               this->cols * sizeof(float));
  }
  pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
  pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
  pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));
  // The input buffer must hold the fp32-widened tile plus the raw T-typed
  // staging area used before the in-place cast.
  int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T));
  perLoopColsAlignBytes =
      Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES));
  pipe->InitBuffer(inputXInQueue, 1, perLoopColsAlignBytes);
  pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
  pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
  pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
  pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
  pipe->InitBuffer(scaleOutZeroQueue, 1, BLOCK_BYTES);
}
// Entry point for this stage on each active core: prepares padding state,
// then runs either the tiled (colLoops > 1) or the single-tile row loop, and
// finishes by zero-filling trailing experts.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstAndGather<T, TilingData>::Process() {
  if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
    AssistInit();
    // Persistent zero row / zero scale reused for capacity padding.
    this->outTmpLocal = copyOutZeroQueue.DeQue<int8_t>();
    this->scaleOutTmpLocal = scaleOutZeroQueue.DeQue<float>();
    currentLoopRows = perLoopRows;
    if (colLoops > 1) {
      // Rows do not fit in one UB tile: use the two-pass tiled path.
      for (int64_t loop = 0; loop < this->rowLoops; loop++) {
        if (loop == this->rowLoops - 1) {
          currentLoopRows = lastLoopRows;
        }
        CopyIn(loop);
        CopyOutLoops(loop);
      }
    } else {
      smoothLocal = smoothInQueue.AllocTensor<float>();
      // smoothType == 1: one shared smooth vector, loaded once up front.
      if (smoothType == 1) {
        DataCopyExtParams smoothParams{1, static_cast<uint32_t>(this->cols * sizeof(float)), 0, 0, 0};
        DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0});
      }
      for (int64_t loop = 0; loop < this->rowLoops; loop++) {
        if (loop == this->rowLoops - 1) {
          currentLoopRows = lastLoopRows;
        }
        CopyIn(loop);
        CopyOut(loop);
      }
      smoothInQueue.FreeTensor(smoothLocal);
    }
    // Zero-fill capacity slots of experts after the last routed token.
    CopyOutRemain();
  }
}
} // namespace MoeInitRoutingQuantV2
#endif // MOE_V2_SRC_TO_DST_AND_GATHER_H

View File

@@ -0,0 +1,164 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_src_to_dst_op.h
* \brief
*/
#ifndef INNER_MOE_V2_SRC_TO_DST_H
#define INNER_MOE_V2_SRC_TO_DST_H
#include "moe_v2_common.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Stage that builds the src->dst row mapping (expandSrcToDstRow) from the
// dst->src map produced by the sort, generating consecutive destination
// indices via the `assist` index table.
class MoeV2SrcToDstOp {
public:
  __aicore__ inline MoeV2SrcToDstOp(){};
  template <typename TilingData>
  __aicore__ inline void Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
private:
  // Loads one loop of dst->src row indices into UB.
  __aicore__ inline void CopyIn(int64_t progress);
  // Materializes the destination indices for this loop via the assist table.
  __aicore__ inline void Compute(int64_t progress);
  // Scatters each destination index to expandSrcToDstRow[srcRow].
  __aicore__ inline void CopyOut();
  // Cross-core barrier (skipped for single-core launches).
  __aicore__ inline void SyncAll();
  // Copies the assist table to UB and biases it by this core's row offset.
  __aicore__ inline void AssistInit();
private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;
  TQue<QuePosition::VECOUT, 1> copyOutQueue;
  TBuf<TPosition::VECCALC> assistBuffer;  // UB copy of the assist index table
  GlobalTensor<int32_t> expandDstToSrcRowGm;
  GlobalTensor<int32_t> expandSrcToDstRowGm;
  GlobalTensor<int32_t> assistGm;
  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;
  int64_t coreNum;
  int64_t blockIdx;
  int64_t totalLength;  // n * k
  int64_t currentLoopRows;
  int64_t coreRows;
  int64_t perLoopRows;
  int64_t lastLoopRows;
};
// Stage the assist index table in UB and shift every entry by this core's
// starting row so the generated destination indices are globally unique.
__aicore__ inline void MoeV2SrcToDstOp::AssistInit() {
#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1
  OOMCheckAddrRange(assistGm.GetPhyAddr(), ASSIST_NUM * sizeof(int32_t));
#endif
  LocalTensor<int32_t> indexTable = assistBuffer.Get<int32_t>(ASSIST_NUM);
  DataCopy(indexTable, assistGm, ASSIST_NUM);
  // Wait for the MTE2 load to land before the vector Adds consumes it.
  SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
  Adds(indexTable, indexTable, static_cast<int32_t>(this->blockIdx * this->srcToDstTilingData->perCoreRows),
       ASSIST_NUM);
}
// Fetch this loop's slice of the dst->src row map into UB.
__aicore__ inline void MoeV2SrcToDstOp::CopyIn(int64_t progress) {
  LocalTensor<int32_t> rowLocal = copyInQueue.AllocTensor<int32_t>();
  DataCopy(rowLocal, expandDstToSrcRowGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t)));
  copyInQueue.EnQue<int32_t>(rowLocal);
}
// Produce the destination indices for this loop: each ASSIST_INDEX_NUM-wide
// slice of the biased assist table is offset by its position within the loop.
__aicore__ inline void MoeV2SrcToDstOp::Compute(int64_t progress) {
  LocalTensor<int32_t> dstLocal = copyOutQueue.AllocTensor<int32_t>();
  LocalTensor<int32_t> indexTable = assistBuffer.Get<int32_t>(ASSIST_NUM);
  pipe_barrier(PIPE_V);
  int64_t tileCount = Ceil(currentLoopRows, ASSIST_INDEX_NUM);
  for (int64_t tile = 0; tile < tileCount; tile++) {
    Adds(dstLocal[tile * ASSIST_NUM], indexTable,
         static_cast<int32_t>(this->perLoopRows * progress + tile * ASSIST_INDEX_NUM), ASSIST_NUM);
  }
  pipe_barrier(PIPE_V);
  copyOutQueue.EnQue<int32_t>(dstLocal);
}
// Scatter: for each row of the loop, write its destination index to
// expandSrcToDstRow at the source-row position read from the input tensor.
__aicore__ inline void MoeV2SrcToDstOp::CopyOut() {
  LocalTensor<int32_t> srcRows = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> dstIdxLocal = copyOutQueue.DeQue<int32_t>();
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  DataCopyParams singleElemParams;
  singleElemParams.blockCount = 1;
  singleElemParams.blockLen = sizeof(int32_t);
  for (int64_t row = 0; row < currentLoopRows; row++) {
    uint32_t target = srcRows.GetValue(row);
    DataCopyPad(expandSrcToDstRowGm[target], dstIdxLocal[row * INT32_ONE_BLOCK_NUM], singleElemParams);
  }
  copyInQueue.FreeTensor(srcRows);
  copyOutQueue.FreeTensor(dstIdxLocal);
}
// Cross-core barrier; a single-core launch needs no synchronization, and the
// CPU twin (KT test) build compiles the barrier out entirely.
__aicore__ inline void MoeV2SrcToDstOp::SyncAll() {
  if (coreNum != 1) {
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
  }
}
// Bind global buffers and size the UB queues for the src->dst index scatter
// stage.
// Fix: removed the unused local `int64_t blockNum = GetBlockNum();` (dead
// code — the value was never read).
template <typename TilingData>
__aicore__ inline void MoeV2SrcToDstOp::Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData,
                                             TPipe* tPipe) {
  this->pipe = tPipe;
  this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
  this->coreNum = tilingData->coreNum;
  this->totalLength = tilingData->n * tilingData->k;
  this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp);
  // The last active core may process fewer rows than the others.
  if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
    this->coreRows = this->srcToDstTilingData->lastCoreRows;
    this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
    this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
  } else {
    this->coreRows = this->srcToDstTilingData->perCoreRows;
    this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
    this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
  }
  expandSrcToDstRowGm.SetGlobalBuffer((__gm__ int32_t*)expandSrcToDstRow, Align(this->totalLength, sizeof(int32_t)));
  // This core's slice of the dst->src map starts after the first workspace
  // region of Align(totalLength) int32 entries.
  expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) +
                                          this->blockIdx * this->srcToDstTilingData->perCoreRows,
                                      Align(this->coreRows, sizeof(int32_t)));
  assistGm.SetGlobalBuffer((__gm__ int32_t*)assist, ASSIST_NUM);
  pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES);
  pipe->InitBuffer(copyOutQueue, 1, Ceil(this->perLoopRows, ASSIST_NUM) * ASSIST_NUM * BLOCK_BYTES);
  pipe->InitBuffer(assistBuffer, ASSIST_NUM * sizeof(int32_t));
}
// Drive the per-core loops: all full-size loops first, then the tail loop
// with its smaller row count; finish with the cross-core barrier.
__aicore__ inline void MoeV2SrcToDstOp::Process() {
  if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
    int64_t loopCount = (coreRows + perLoopRows - 1) / perLoopRows;
    currentLoopRows = perLoopRows;
    AssistInit();
    for (int64_t iter = 0; iter < loopCount - 1; iter++) {
      CopyIn(iter);
      Compute(iter);
      CopyOut();
    }
    currentLoopRows = lastLoopRows;
    CopyIn(loopCount - 1);
    Compute(loopCount - 1);
    CopyOut();
  }
  this->SyncAll();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_SRC_TO_DST_H

View File

@@ -0,0 +1,269 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file moe_v2_src_to_dst_with_capacity.h
* \brief
*/
#ifndef INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
#define INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
#include "moe_v2_common.h"
namespace MoeInitRoutingQuantV2 {
using namespace AscendC;
using namespace optiling;
// Unquantized variant of the capacity scatter: records each routed token's
// destination slot within its expert's capacity region in expandedRowIdx and
// zero-fills unused slots of expandedX. T is the activation dtype.
template <typename T, typename TilingData>
class MoeV2SrcToDstWithCapacity {
public:
  __aicore__ inline MoeV2SrcToDstWithCapacity(){};
  __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace,
                              const TilingData* tilingData, TPipe* tPipe);
  __aicore__ inline void Process();
private:
  // Loads one loop of (dst->src row ids, sorted expert ids) into UB.
  __aicore__ inline void CopyIn(int64_t progress);
  // Scatters row indices and zero-pads skipped experts' capacity slots.
  __aicore__ inline void CopyOut(int64_t progress);
  // Zero-fills slots of the trailing experts (last active core only).
  __aicore__ inline void CopyOutRemain();
  __aicore__ inline void SyncAll();
  // Prepares the zero padding row and the cross-core expert carry-over state.
  __aicore__ inline void AssistInit();
private:
  TPipe* pipe;
  TQue<QuePosition::VECIN, 1> copyInQueue;
  TQue<QuePosition::VECOUT, 1> copyOutQueue;
  TQue<QuePosition::VECOUT, 1> copyOutZeroQueue;  // zero row used for padding
  GlobalTensor<int32_t> expandDstToSrcRowGm;
  GlobalTensor<int32_t> expandedRowIdxGm;
  GlobalTensor<int32_t> expertIdxValueGm;  // per-core (expertId, count) pairs
  GlobalTensor<int32_t> expandedExpertIdxGm;
  GlobalTensor<T> expandedXGm;
  LocalTensor<T> outTmpLocal;  // persistent zero row used for padding
  const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData;
  int64_t coreNum;
  int64_t blockIdx;
  int64_t totalLength;  // n * k
  int64_t currentLoopRows;
  int64_t coreRows;
  int64_t perLoopRows;
  int64_t lastLoopRows;
  int64_t rowLoops;
  int64_t expertCapacity;
  int64_t expertNum;
  int64_t cols;
  int64_t perLoopCols;
  int64_t lastLoopCols;
  int64_t colLoops;
  int64_t tokenCount = 0;
  int32_t lastExpertId = -1;
  int32_t lastCoreExpertId = 0;
  int32_t lastCoreExpertIdNum = 0;
};
// Build the zero padding row (int8 rows are filled via int16 lanes) and
// recover how many slots of our starting expert earlier cores already used.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::AssistInit() {
  if constexpr (IsSameType<T, int8_t>::value) {
    LocalTensor<int16_t> zeroRow = copyOutZeroQueue.AllocTensor<int16_t>();
    Duplicate<int16_t>(zeroRow, static_cast<int16_t>(0), this->perLoopCols);
    copyOutZeroQueue.EnQue<int16_t>(zeroRow);
  } else {
    LocalTensor<T> zeroRow = copyOutZeroQueue.AllocTensor<T>();
    Duplicate<T>(zeroRow, static_cast<T>(0), this->perLoopCols);
    copyOutZeroQueue.EnQue<T>(zeroRow);
  }
  if (this->blockIdx == 0) {
    return;
  }
  // Seed with the preceding core's last expert and its token count ...
  this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2);
  this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1);
  // ... then add counts of every earlier core that ended on the same expert.
  for (int64_t core = this->blockIdx - 2; core >= 0; core--) {
    if (expertIdxValueGm.GetValue(core * 2) < this->lastCoreExpertId) {
      break;
    }
    this->lastCoreExpertIdNum += expertIdxValueGm.GetValue(core * 2 + 1);
  }
}
// Load the loop's dst->src row map and sorted expert ids into one UB tensor:
// [0, len) holds the row map, [len, 2*len) the expert ids.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyIn(int64_t progress) {
  LocalTensor<int32_t> rowLocal = copyInQueue.AllocTensor<int32_t>();
  int64_t alignedRows = Align(currentLoopRows, sizeof(int32_t));
  DataCopy(rowLocal, expandDstToSrcRowGm[progress * perLoopRows], alignedRows);
  DataCopy(rowLocal[alignedRows], expandedExpertIdxGm[progress * perLoopRows], alignedRows);
  copyInQueue.EnQue<int32_t>(rowLocal);
}
// Scatter loop: zero-pad the remaining capacity slots of every expert passed
// over, then record the destination slot of each token that still fits its
// expert's capacity. Tokens beyond the capacity are dropped.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyOut(int64_t progress) {
  LocalTensor<int32_t> inLocal = copyInQueue.DeQue<int32_t>();
  LocalTensor<int32_t> outLocal = copyOutQueue.AllocTensor<int32_t>();
  int64_t length = Align(currentLoopRows, sizeof(int32_t));
  DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(sizeof(int32_t)), 0, 0, 0};
  SetWaitFlag<HardEvent::MTE2_S>(HardEvent::MTE2_S);
  // First loop on this core: resume from the carry-over state of AssistInit.
  if (this->lastExpertId == -1) {
    this->lastExpertId = this->lastCoreExpertId;
    this->tokenCount = this->lastCoreExpertIdNum;
  }
  for (int64_t idx = 0; idx < currentLoopRows; idx++) {
    int32_t expertIdx = inLocal[length].GetValue(idx);
    SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
    int32_t index = 0;
    // Zero-fill the unfilled capacity slots of all experts we skipped over.
    while (this->lastExpertId < expertIdx) {
      while (this->tokenCount < this->expertCapacity) {
        index = this->lastExpertId * this->expertCapacity + this->tokenCount;
        int64_t col = this->perLoopCols;
        for (int64_t i = 0; i < this->colLoops; i++) {
          if (i == this->colLoops - 1) {
            col = this->lastLoopCols;
          }
#ifdef __CCE_KT_TEST__
          // CPU-twin debugging cannot use multi-core synchronization, so
          // `index` may hold uninitialized dirty data; guard out-of-range
          // writes in that build only.
          if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) {
            continue;
          }
#endif
          DataCopyExtParams copyParams1{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(T)), 0, 0, 0};
          DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1);
          SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
        }
        this->tokenCount++;
      }
      this->tokenCount = 0;
      this->lastExpertId++;
    }
    if (this->tokenCount < this->expertCapacity) {
      int32_t outOffset = inLocal.GetValue(idx);
      index = expertIdx * this->expertCapacity + this->tokenCount;
      outLocal.SetValue(0, index);
      SetWaitFlag<HardEvent::S_MTE3>(HardEvent::S_MTE3);
      DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
      SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
      this->tokenCount++;
    }
  }
  copyInQueue.FreeTensor(inLocal);
  copyOutQueue.FreeTensor(outLocal);
}
// After the last tile has been processed, the final active core pads the
// capacity regions of all remaining experts (lastExpertId .. expertNum-1)
// with outTmpLocal. Every core releases the padding tensor on exit.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyOutRemain() {
    // Only the last active core owns the trailing experts.
    if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
        copyOutZeroQueue.FreeTensor(this->outTmpLocal);
        return;
    }
    while (this->lastExpertId < this->expertNum) {
        while (this->tokenCount < this->expertCapacity) {
            int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
            int64_t col = this->perLoopCols;
            // Pad one token row, one column chunk at a time.
            for (int64_t i = 0; i < this->colLoops; i++) {
                if (i == this->colLoops - 1) {
                    col = this->lastLoopCols;
                }
                DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(col * sizeof(T)), 0, 0, 0};
                DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
                SetWaitFlag<HardEvent::MTE3_S>(HardEvent::MTE3_S);
            }
            this->tokenCount++;
        }
        this->tokenCount = 0;
        this->lastExpertId++;
    }
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
}
// Cross-core barrier. Skipped for single-core launches, and compiled out
// under CPU twin debugging (__CCE_KT_TEST__) where hardware multi-core
// synchronization is unavailable.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::SyncAll() {
    if (coreNum == 1) {
        return;
    }
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
}
// Loads tiling parameters, computes this core's row range and loop split,
// and binds the global buffers and UB queues.
// Workspace layout (int32 units, length = Align(n*k)):
//   [0, length)            sorted expert indices (per-core slices)
//   [length, 2*length)     dst->src row mapping (per-core slices)
//   [2*length, +coreNum*2) per-core (expertId, count) pairs read by AssistInit
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                                                      GM_ADDR workspace, const TilingData* tilingData,
                                                                      TPipe* tPipe) {
    int64_t blockNum = GetBlockNum();  // NOTE(review): unused local — candidate for removal
    this->pipe = tPipe;
    // Flattened vector-core index across both sub-blocks.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->coreNum = tilingData->coreNum;
    this->totalLength = tilingData->n * tilingData->k;
    this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp);
    this->expertNum = tilingData->expertNum;
    this->expertCapacity = tilingData->expertCapacity;
    this->cols = tilingData->cols;
    // The last active core may own a shorter row range with its own loop split.
    if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
        this->coreRows = this->srcToDstTilingData->lastCoreRows;
        this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->srcToDstTilingData->perCoreRows;
        this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->perCoreLoops;
    }
    this->perLoopCols = this->srcToDstTilingData->perLoopCols;
    this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
    this->colLoops = this->srcToDstTilingData->colLoops;
    int64_t length = Align(this->totalLength, sizeof(int32_t));
    expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length);
    // Output is a dense [expertNum, expertCapacity, cols] tensor.
    expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, this->expertNum * this->expertCapacity * this->cols);
    expandedExpertIdxGm.SetGlobalBuffer(
        (__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));
    expandDstToSrcRowGm.SetGlobalBuffer(
        (__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows,
        Align(this->coreRows, sizeof(int32_t)));
    expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2);
    // Input queue holds the two metadata segments loaded by CopyIn.
    pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
    pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
    // Padding buffer: int8 payloads pad with int16-sized elements.
    if constexpr (IsSameType<T, int8_t>::value) {
        pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));
    } else {
        pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(T)));
    }
}
// Entry point: only the first needCoreNum cores process rows; every core
// then meets at the final barrier so downstream stages observe consistent
// expandedX / expandedRowIdx contents.
template <typename T, typename TilingData>
__aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::Process() {
    if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
        AssistInit();
        // Padding-source tensor used to fill unused capacity slots.
        this->outTmpLocal = copyOutZeroQueue.DeQue<T>();
        currentLoopRows = perLoopRows;
        for (int64_t loop = 0; loop < this->rowLoops; loop++) {
            // Last iteration covers the (possibly shorter) tail tile.
            if (loop == this->rowLoops - 1) {
                currentLoopRows = lastLoopRows;
            }
            CopyIn(loop);
            CopyOut(loop);
        }
        CopyOutRemain();
    }
    this->SyncAll();
}
} // namespace MoeInitRoutingQuantV2
#endif // INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H

View File

@@ -0,0 +1,66 @@
#pragma once
namespace optiling {
// Per-core hardware resource capacities used by tiling computations.
// Sizes are presumably in bytes — confirm against the platform query API.
struct AiCoreParams {
    uint64_t ubSize;    // unified buffer size
    uint64_t blockDim;  // available block dimension
    uint64_t aicNum;    // number of AI cube cores
    uint64_t l1Size;    // L1 buffer size
    uint64_t l0aSize;   // L0A buffer size
    uint64_t l0bSize;   // L0B buffer size
    uint64_t l0cSize;   // L0C buffer size
};
// Template-method driver for operator tiling. DoTiling runs the fixed
// pipeline (shape/attr parsing -> platform query -> op tiling -> workspace
// sizing -> post processing) and caches the tiling key on success.
// Concrete tilings implement the pure-virtual stage hooks.
class TilingBaseClass {
public:
    // Polymorphic base: a virtual destructor makes deletion through a
    // base pointer well-defined.
    virtual ~TilingBaseClass() = default;

    // Runs all tiling stages in order, stopping at the first failure.
    // Returns true and sets tilingKey_ when every stage succeeds.
    bool DoTiling(
        int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0,
        int64_t aivCoreNum, int64_t ubSizePlatForm)
    {
        if (!GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode,
                               expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag,
                               inuptXDtypeSize, quantMode, scaleDim0)) {
            return false;
        }
        if (!GetPlatformInfo(aivCoreNum, ubSizePlatForm)) {
            return false;
        }
        if (!DoOpTiling()) {
            return false;
        }
        if (!GetWorkspaceSize()) {
            return false;
        }
        if (!PostTiling()) {
            return false;
        }
        tilingKey_ = GetTilingKey();
        return true;
    }

    // Stage hooks implemented by concrete tilings. Each returns false to
    // abort the pipeline. (Kept public so external drivers can call them.)
    virtual bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) = 0;
    virtual bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity,
        int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag,
        bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) = 0;
    virtual bool DoOpTiling() = 0;
    virtual bool GetWorkspaceSize() = 0;
    virtual bool PostTiling() = 0;
    virtual uint64_t GetTilingKey() const = 0;

    // Tiling results populated by the stages / DoTiling.
    uint32_t blockDim_{0};
    uint64_t workspaceSize_{0};
    uint64_t tilingKey_{0};
    AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0};
};
}

View File

@@ -0,0 +1,376 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file moe_token_unpermute.h
* \brief
*/
#ifndef MOE_TOKEN_UNPERMUTE
#define MOE_TOKEN_UNPERMUTE
#include "kernel_operator.h"
#include "moe_token_unpermute_tiling.h"
using namespace AscendC;
// Kernel that gathers each output token from its top_k permuted rows,
// summing the contributions and, when PROBS is set, weighting each row by
// its per-(token, expert) probability before the sum.
// T1: token element type, T2: sorted-index element type, T3: prob type.
template <typename T1, typename T2, typename T3, bool PROBS> class KernelMoeTokenUnpermute {
public:
    __aicore__ inline KernelMoeTokenUnpermute()
    {
    }
    // Binds global buffers, loads tiling, allocates UB queues/buffers.
    __aicore__ inline void Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs,
                                GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data);
    // Processes this core's share of output tokens, chunk by chunk.
    __aicore__ inline void Process();

protected:
    // Accumulates a contiguous batch of output tokens starting at out_offset.
    __aicore__ inline void CalMultiOutToken(const int64_t out_offset, const int64_t out_tokens_number);
    // Accumulates one output token across all hidden-dimension slices.
    __aicore__ inline void CalSingleOutToken(const int64_t start_token, const int64_t out_token_idx);
    // Accumulates one hidden-dimension slice of one output token.
    __aicore__ inline void CalPartOutToken(const int64_t start_token, const int64_t h_index, const int64_t h_length,
                                           const int64_t out_token_index);
    // Loads one hidden slice of a permuted token into the input queue.
    __aicore__ inline void CopyTokenIn(const T2 in_token_index, const int64_t h_index, const int64_t h_length);
    // Initializes the accumulator with the first contributing token.
    __aicore__ inline void CalFirstToken(const float prob_value, const int64_t h_length);
    // Adds a subsequent contributing token into the accumulator.
    __aicore__ inline void CalToken(const float prob_value, const int64_t h_length);
    // Writes the accumulated slice for one output token back to GM.
    __aicore__ inline void CopyOut(const int64_t out_token_index, const int64_t h_index, const int64_t h_length);

    TPipe pipe;
    TQue<QuePosition::VECIN, 1> tokens_inque, indices_inque, probs_inque;
    TBuf<TPosition::VECCALC> temp_buffer0, temp_buffer1, temp_buffer2;
    TQue<QuePosition::VECOUT, 1> outque;
    GlobalTensor<T1> tokensGM, outGM;
    GlobalTensor<T2> indicesGM;
    GlobalTensor<T3> probsGM;
    LocalTensor<T2> indicesLocal;
    // token_tensor0: fp32 accumulator; token_tensor1: fp32 scratch for casts.
    LocalTensor<float> token_tensor0, token_tensor1, probs_tensor;
    DataCopyPadExtParams<T1> extParams1{false, 0, 0, 0};
    DataCopyPadExtParams<T2> extParams2{false, 0, 0, 0};
    DataCopyPadExtParams<T3> extParams3{false, 0, 0, 0};
    // Shared pad params; blockLen is rewritten before every DataCopyPad.
    DataCopyExtParams copyParams{1, 0, 0, 0, 0};
    constexpr static uint32_t BLOCK_SIZE = 32;
    constexpr static uint32_t ALIGN_512 = 512;
    int64_t hidden_size;
    int64_t top_k;
    // Rows with index >= num_out_tokens are treated as dropped.
    int64_t num_out_tokens;
    // Hidden-dimension split: slice length, full-slice count, tail length.
    int64_t hidden_splited_length;
    int64_t hidden_splited_num;
    int64_t hidden_splited_remain;
    // Per-core token range and its chunking.
    int64_t tokens_core_length;
    int64_t tokens_core_remain;
    int64_t tokens_splited_length;
    int64_t tokens_splited_num;
    int64_t tokens_splited_remain;
    int32_t blockIdx;
    int32_t blockNum;
};
// Binds global buffers and UB queues and distributes output tokens across
// cores. The first tokens_core_remain cores each take one extra token; GM
// offsets are computed accordingly for both the index/prob inputs and the
// unpermuted output. (Code change vs. original: removed a stray trailing
// semicolon after the function body.)
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void
KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs,
                                                 GM_ADDR unpermuted_tokens,
                                                 const MoeTokenUnpermuteTilingData *__restrict tiling_data)
{
    // Flattened vector-core index across both sub-blocks.
    this->blockIdx = get_block_idx() + get_subblockid() * get_block_num();
    this->blockNum = get_block_num() * get_subblockdim();
    if (blockIdx >= blockNum) {
        return;
    }
    ASSERT(blockNum != 0 && "block dim can not be zero!");
    // row_input
    this->hidden_size = tiling_data->hidden_size;
    this->top_k = tiling_data->top_k;
    this->num_out_tokens = tiling_data->num_out_tokens;
    // hidden_tiling
    this->hidden_splited_length = tiling_data->hidden_splited_length;
    this->hidden_splited_num = tiling_data->hidden_splited_num;
    this->hidden_splited_remain = tiling_data->hidden_splited_remain;
    // token_tiling
    this->tokens_core_length = tiling_data->tokens_core_length;
    this->tokens_core_remain = tiling_data->tokens_core_remain;
    this->tokens_splited_length = tiling_data->tokens_splited_length;
    this->tokens_splited_num = tiling_data->tokens_splited_num;
    this->tokens_splited_remain = tiling_data->tokens_splited_remain;
    // Handle the per-core token remainder: the first tokens_core_remain
    // cores each process one additional output token.
    if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) {
        this->tokens_core_length += 1;
        this->tokens_splited_remain += 1;
    }
    int64_t hidden_splited_length_align512 = (this->hidden_splited_length + ALIGN_512 - 1) & ~(ALIGN_512 - 1);
    int64_t block_length = this->tokens_core_length * this->top_k;
    int64_t block_splited_length = this->tokens_splited_length * this->top_k;
    // Input-side GM offset: earlier remainder cores own one extra token each.
    int64_t block_offset;
    if (this->tokens_core_remain > 0) {
        if (blockIdx < this->tokens_core_remain) {
            block_offset = block_length * blockIdx;
        } else {
            block_offset = (block_length + this->top_k) * this->tokens_core_remain +
                           block_length * (blockIdx - this->tokens_core_remain);
        }
    } else {
        block_offset = block_length * blockIdx;
    }
    this->tokensGM.SetGlobalBuffer((__gm__ T1 *)permuted_tokens);
    this->indicesGM.SetGlobalBuffer((__gm__ T2 *)sorted_indices + block_offset, block_length);
    // Output-side GM offset mirrors the same remainder distribution.
    int64_t out_block_offset;
    if (this->tokens_core_remain > 0) {
        if (blockIdx < this->tokens_core_remain) {
            out_block_offset = this->tokens_core_length * blockIdx * hidden_size;
        } else {
            out_block_offset = (this->tokens_core_length + 1) * this->tokens_core_remain +
                               this->tokens_core_length * (blockIdx - this->tokens_core_remain);
            out_block_offset *= this->hidden_size;
        }
    } else {
        out_block_offset = this->tokens_core_length * blockIdx * hidden_size;
    }
    this->outGM.SetGlobalBuffer((__gm__ T1 *)unpermuted_tokens + out_block_offset,
                                this->tokens_core_length * this->hidden_size);
    this->pipe.InitBuffer(tokens_inque, tiling_data->buffer_num, hidden_splited_length_align512 * sizeof(T1));
    this->pipe.InitBuffer(indices_inque, 1, block_splited_length * (sizeof(T2)));
    this->pipe.InitBuffer(outque, 1, hidden_splited_length_align512 * sizeof(T1));
    // Non-fp32 tokens accumulate in dedicated fp32 scratch buffers.
    if constexpr (!IsSameType<T1, float>::value) {
        this->pipe.InitBuffer(temp_buffer0, hidden_splited_length_align512 * sizeof(float) + 256);
        this->pipe.InitBuffer(temp_buffer1, hidden_splited_length_align512 * sizeof(float));
        this->token_tensor0 = this->temp_buffer0.template Get<float>();
        this->token_tensor1 = this->temp_buffer1.template Get<float>();
    }
    if constexpr (PROBS) {
        this->probsGM.SetGlobalBuffer((__gm__ T3 *)probs + block_offset, block_length);
        this->pipe.InitBuffer(probs_inque, 1, block_splited_length * (sizeof(T3)));
        // Non-fp32 probs are cast into a dedicated fp32 buffer.
        if constexpr (!IsSameType<T3, float>::value) {
            this->pipe.InitBuffer(temp_buffer2, block_splited_length * sizeof(float));
            this->probs_tensor = this->temp_buffer2.template Get<float>();
        }
    }
}
// Drives this core's work: output tokens are consumed in fixed-size chunks,
// followed by a shorter tail chunk when the per-core count does not divide
// evenly by the chunk size.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Process()
{
    // Cores beyond the configured block count have nothing to do.
    if (blockIdx >= blockNum) {
        return;
    }
    // Walk the full-size chunks, carrying the running output-token offset.
    int64_t chunkStart = 0;
    for (int64_t chunk = 0; chunk < this->tokens_splited_num; ++chunk) {
        CalMultiOutToken(chunkStart, this->tokens_splited_length);
        chunkStart += this->tokens_splited_length;
    }
    // Tail chunk for the remainder, if any.
    if (this->tokens_splited_remain > 0) {
        CalMultiOutToken(chunkStart, this->tokens_splited_remain);
    }
}
// Processes one chunk of output tokens: loads their sorted indices (and,
// when PROBS, their probabilities) into UB, then accumulates each output
// token from its top_k permuted rows via CalSingleOutToken.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalMultiOutToken(const int64_t out_offset,
                                                                                    const int64_t out_tokens_number)
{
    this->indicesLocal = this->indices_inque.template AllocTensor<T2>();
    // Each output token consumes top_k consecutive input entries.
    int64_t in_offset = out_offset * this->top_k;
    this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T2);
    DataCopyPad(this->indicesLocal, this->indicesGM[in_offset], this->copyParams, this->extParams2);
    this->indices_inque.template EnQue(this->indicesLocal);
    if constexpr (PROBS) {
        LocalTensor<T3> temp_probs_tensor = this->probs_inque.template AllocTensor<T3>();
        this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T3);
        DataCopyPad(temp_probs_tensor, this->probsGM[in_offset], this->copyParams, this->extParams3);
        this->probs_inque.template EnQue(temp_probs_tensor);
        temp_probs_tensor = this->probs_inque.template DeQue<T3>();
        // Non-fp32 probs are cast into the fp32 scratch buffer; fp32 probs
        // are used in place (and freed at the end of this function).
        if constexpr (!IsSameType<T3, float>::value) {
            Cast(this->probs_tensor, temp_probs_tensor, RoundMode::CAST_NONE, out_tokens_number * this->top_k);
            this->probs_inque.FreeTensor(temp_probs_tensor);
            PipeBarrier<PIPE_V>();
        } else {
            this->probs_tensor = temp_probs_tensor;
        }
    }
    this->indicesLocal = this->indices_inque.template DeQue<T2>();
    for (int64_t out_token_idx = 0; out_token_idx < out_tokens_number; ++out_token_idx) {
        CalSingleOutToken(out_token_idx * this->top_k, out_offset + out_token_idx);
    }
    // Free Tensor
    this->indices_inque.FreeTensor(this->indicesLocal);
    if constexpr (PROBS && IsSameType<T3, float>::value) {
        this->probs_inque.FreeTensor(this->probs_tensor);
    }
}
// Accumulates one output token over the hidden dimension, one slice at a
// time, then the tail slice when hidden_size does not split evenly.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalSingleOutToken(const int64_t start_token,
                                                                                     const int64_t out_token_idx)
{
    // Full-length hidden slices.
    int64_t slice = 0;
    while (slice < this->hidden_splited_num) {
        CalPartOutToken(start_token, slice, this->hidden_splited_length, out_token_idx);
        ++slice;
    }
    // Remainder slice (slice now equals hidden_splited_num).
    if (this->hidden_splited_remain > 0) {
        CalPartOutToken(start_token, slice, this->hidden_splited_remain, out_token_idx);
    }
}
// Accumulates one hidden slice of one output token from its top_k permuted
// rows. Rows whose sorted index is >= num_out_tokens are skipped (dropped
// tokens); if the very first row is dropped, the accumulator is zeroed.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void
KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalPartOutToken(const int64_t start_token, const int64_t h_index,
                                                            const int64_t h_length, const int64_t out_token_index)
{
    // For fp32 tokens the accumulator doubles as the output queue tensor.
    if constexpr (IsSameType<T1, float>::value) {
        this->token_tensor0 = this->outque.template AllocTensor<T1>();
    }
    int64_t end_token = start_token + this->top_k;
    T2 cal_token_idx = this->indicesLocal.GetValue(start_token);
    // Handle the first token's data
    if (cal_token_idx < this->num_out_tokens) {
        float probsValue = 0;
        if constexpr (PROBS) {
            probsValue = this->probs_tensor.GetValue(start_token);
        }
        CopyTokenIn(cal_token_idx, h_index, h_length);
        PipeBarrier<PIPE_V>();
        CalFirstToken(probsValue, h_length);
    } else {
        // First contributor dropped: start the accumulator from zeros.
        PipeBarrier<PIPE_V>();
        Duplicate(this->token_tensor0, static_cast<float>(0), h_length);
    }
    // Handle the remaining tokens
    for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) {
        cal_token_idx = this->indicesLocal.GetValue(token_index);
        if (cal_token_idx < this->num_out_tokens) {
            float probsValue = 0;
            if constexpr (PROBS) {
                probsValue = this->probs_tensor.GetValue(token_index);
            }
            CopyTokenIn(cal_token_idx, h_index, h_length);
            PipeBarrier<PIPE_V>();
            CalToken(probsValue, h_length);
        }
    }
    // Write out the accumulated result
    CopyOut(out_token_index, h_index, h_length);
}
// Loads one hidden slice of permuted token `in_token_index` into the token
// input queue, using aligned DataCopy when the byte length is a multiple of
// the 32B block size and DataCopyPad otherwise.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CopyTokenIn(const T2 in_token_index,
                                                                               const int64_t h_index,
                                                                               const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template AllocTensor<T1>();
    int64_t offset = in_token_index * this->hidden_size + h_index * this->hidden_splited_length;
    if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) {
        DataCopy(tokensLocal, this->tokensGM[offset], h_length);
    } else {
        this->copyParams.blockLen = h_length * sizeof(T1);
        DataCopyPad(tokensLocal, this->tokensGM[offset], this->copyParams, this->extParams1);
    }
    this->tokens_inque.template EnQue(tokensLocal);
}
// Initializes the fp32 accumulator with the first contributing token:
// casts (or copies, for fp32 input) the slice into token_tensor0, then
// scales it by the probability when PROBS is enabled.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalFirstToken(const float prob_value,
                                                                                 const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template DeQue<T1>();
    if constexpr (!IsSameType<T1, float>::value) {
        Cast(this->token_tensor0, tokensLocal, RoundMode::CAST_NONE, h_length);
    } else {
        // fp32 path: plain copy, length rounded up to a 32B multiple.
        uint64_t byteAlign32 = (h_length * sizeof(float) + BLOCK_SIZE - 1) & ~(BLOCK_SIZE - 1);
        DataCopy(this->token_tensor0, tokensLocal, byteAlign32 / sizeof(float));
    }
    this->tokens_inque.FreeTensor(tokensLocal);
    if constexpr (PROBS) {
        PipeBarrier<PIPE_V>();
        Muls(this->token_tensor0, this->token_tensor0, prob_value, h_length);
    }
}
// Adds one more contributing token into the fp32 accumulator, scaling by
// the probability first when PROBS is enabled. Non-fp32 inputs go through
// the token_tensor1 scratch buffer; fp32 inputs are scaled in place.
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalToken(const float prob_value,
                                                                            const int64_t h_length)
{
    LocalTensor<T1> tokensLocal = this->tokens_inque.template DeQue<T1>();
    if constexpr (!IsSameType<T1, float>::value) {
        Cast(this->token_tensor1, tokensLocal, RoundMode::CAST_NONE, h_length);
        this->tokens_inque.FreeTensor(tokensLocal);
        if constexpr (PROBS) {
            PipeBarrier<PIPE_V>();
            Muls(this->token_tensor1, this->token_tensor1, prob_value, h_length);
        }
        PipeBarrier<PIPE_V>();
        Add(this->token_tensor0, this->token_tensor0, this->token_tensor1, h_length);
    } else {
        if constexpr (PROBS) {
            Muls(tokensLocal, tokensLocal, prob_value, h_length);
            PipeBarrier<PIPE_V>();
        }
        Add(this->token_tensor0, this->token_tensor0, tokensLocal, h_length);
        this->tokens_inque.FreeTensor(tokensLocal);
    }
}
// Writes the accumulated slice of one output token to GM, casting the fp32
// accumulator back to T1 first when T1 is not float (for fp32 the
// accumulator itself came from the output queue, see CalPartOutToken).
template <typename T1, typename T2, typename T3, bool PROBS>
__aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CopyOut(const int64_t out_token_index,
                                                                           const int64_t h_index,
                                                                           const int64_t h_length)
{
    LocalTensor<T1> temp_out_tensors;
    if constexpr (!IsSameType<T1, float>::value) {
        temp_out_tensors = this->outque.template AllocTensor<T1>();
        PipeBarrier<PIPE_V>();
        Cast(temp_out_tensors, this->token_tensor0, RoundMode::CAST_RINT, h_length);
    } else {
        temp_out_tensors = this->token_tensor0;
    }
    this->outque.template EnQue<T1>(temp_out_tensors);
    temp_out_tensors = this->outque.template DeQue<T1>();
    int64_t offset = out_token_index * this->hidden_size + h_index * this->hidden_splited_length;
    // Aligned fast path when the byte length is a multiple of the block size.
    if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) {
        DataCopy(this->outGM[offset], temp_out_tensors, h_length);
    } else {
        this->copyParams.blockLen = h_length * sizeof(T1);
        DataCopyPad(this->outGM[offset], temp_out_tensors, this->copyParams);
    }
    this->outque.FreeTensor(temp_out_tensors);
}
#endif // MOE_TOKEN_UNPERMUTE

View File

@@ -0,0 +1,38 @@
#ifndef MOE_TOKEN_UNPERMUTE_TILING
#define MOE_TOKEN_UNPERMUTE_TILING
// Tiling parameters for the token-unpermute kernel, populated by
// MoeTokenUnpermuteTiling below.
struct MoeTokenUnpermuteTilingData {
    int64_t hidden_size;            // hidden dimension length (n)
    int64_t top_k;                  // experts per token
    int64_t num_out_tokens;         // total permuted rows (m)
    int64_t hidden_splited_length;  // per-slice hidden length
    int64_t hidden_splited_num;     // number of full hidden slices
    int64_t hidden_splited_remain;  // tail hidden slice length
    int64_t tokens_core_length;     // output tokens per core
    int64_t tokens_core_remain;     // cores that take one extra token
    int64_t tokens_splited_length;  // chunk size (tokens per iteration)
    int64_t tokens_splited_num;     // number of full chunks per core
    int64_t tokens_splited_remain;  // tail chunk size per core
    int64_t buffer_num;             // multi-buffer depth for the token queue
};
// Computes tiling for the token-unpermute kernel.
// m: number of permuted rows, n: hidden size, topK: experts per token.
// The hidden dimension is kept as a single slice; the m/topK output tokens
// are split across `coreNum` cores and processed in chunks of <= 600.
// Fixes vs. original: the I64() helper macro leaked past this function into
// every including TU (now replaced by static_cast), and tokens_core_length
// == 0 (fewer output tokens than cores) caused a division by zero when
// deriving the chunk counts (now guarded).
__forceinline__ [host, aicore] void
MoeTokenUnpermuteTiling(int32_t m, int32_t n, int32_t topK, MoeTokenUnpermuteTilingData &tilingData, uint32_t coreNum)
{
    tilingData.hidden_size = static_cast<int64_t>(n);
    tilingData.top_k = static_cast<int64_t>(topK);
    tilingData.num_out_tokens = static_cast<int64_t>(m);
    // Hidden dimension handled in one slice; no remainder.
    tilingData.hidden_splited_length = tilingData.hidden_size;
    tilingData.hidden_splited_num = 1;
    tilingData.hidden_splited_remain = 0;
    uint32_t outTokens = m / topK;
    tilingData.tokens_core_length = static_cast<int64_t>(outTokens / coreNum);
    tilingData.tokens_core_remain = static_cast<int64_t>(outTokens % coreNum);
    // Cap the per-iteration chunk at 600 tokens.
    tilingData.tokens_splited_length = static_cast<int64_t>(min(tilingData.tokens_core_length, 600));
    if (tilingData.tokens_splited_length > 0) {
        tilingData.tokens_splited_num = tilingData.tokens_core_length / tilingData.tokens_splited_length;
        tilingData.tokens_splited_remain = tilingData.tokens_core_length % tilingData.tokens_splited_length;
    } else {
        // Fewer output tokens than cores: this core's base share is zero
        // (remainder cores add their extra token in the kernel's Init).
        tilingData.tokens_splited_num = 0;
        tilingData.tokens_splited_remain = 0;
    }
    tilingData.buffer_num = 4;
}
#endif

View File

@@ -0,0 +1,207 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/epilogue/block/block_epilogue.hpp"
namespace Catlass::Epilogue::Block {
// float scale, dequant per expert
// Per-token dequant epilogue (Atlas A2): for each row of the half-precision
// GEMM result C, multiplies by that token's float scale and stores the
// result as half/bfloat16 into D. UB tiles are multi-buffered across
// UB_STAGES with hand-managed MTE2/V/MTE3 event flags.
// Fix vs. original: removed the dead `ubMul` local binding and the
// `ubMulList` member it referenced — the member was never initialized from
// the UB resource and never used.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequant<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequant<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(
        std::is_same_v<ElementC, half> && (std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
        "The element type template parameters of BlockEpilogue are wrong"
    );
    static_assert(
        std::is_same_v<LayoutC, layout::RowMajor> &&
        std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
        "The layout template parameters of BlockEpilogue are wrong"
    );

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;

    // Expert-parallel routing context. NOTE(review): these fields are stored
    // but not read inside this class — presumably consumed by callers via
    // UpdateParams; confirm before removing.
    struct Params {
        __gm__ int32_t *ptrTokenPerExpert{nullptr};
        int32_t EP;
        int32_t expertPerRank;

        CATLASS_DEVICE
        Params() {}

        CATLASS_DEVICE
        Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_)
            : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
    };

    // Carves per-stage UB buffers (C in half, D in ElementD, C in fp32) and
    // primes the V->MTE2 / MTE3->V event flags for the first iteration.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        size_t ubOffset = 4096;  // first 4KB of UB left untouched — TODO confirm reservation
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        // Per-stage buffer capacity in elements; operator() assumes
        // shapeC.column() <= blockN (not checked at runtime).
        constexpr int32_t blockN = 12000;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += blockN * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += blockN * sizeof(ElementD);
            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;
            // Mark every stage's C buffer free to load and D buffer free to write.
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
            ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += blockN * sizeof(float);
        }
    }

    // Consumes the event flags primed in the constructor / last iteration so
    // the pipeline is balanced before teardown.
    CATLASS_DEVICE
    void Finalize()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }

    CATLASS_DEVICE
    ~BlockEpilogue()
    {
    }

    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }

    // Dequantizes shapeC.row() rows of C, one row per iteration, rotating
    // through the UB_STAGES buffer sets.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
        AscendC::GlobalTensor<ElementD> const &gmD
    )
    {
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();
        uint32_t tileLoops = blockM;
        for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx ++) {
            auto gmTileC = gmC[loopIdx * blockN];
            auto &ubC = ubCList[ubListId];
            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubD = ubDList[ubListId];
            auto gmTileD = gmD[loopIdx * blockN];
            LayoutC layoutUbC{1, blockN};
            // Copy the C row from the GM workspace into UB.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            // Cast C to FP32 in UB.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            // Fetch the per-token scale: row loopIdx of gmPerTokenScale.
            ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Multiply the FP32 C row by the per-token scale.
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
            AscendC::PipeBarrier<PIPE_V>();
            // Cast the product back to fp16/bf16 and store it to D.
            LayoutD layoutUbD{1, blockN};
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            // Advance to the next buffer stage (round robin).
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }

private:
    Params params;
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];
    uint32_t ubListId{0};
    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];
    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP

View File

@@ -0,0 +1,316 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
namespace Catlass::Epilogue::Block {
// float scale, dequant per expert
template <
uint32_t UB_STAGES_,
class CType_,
class LayoutPerTokenScale_,
class DType_,
class TileElemWiseMuls_,
class TileCopy_
>
class BlockEpilogue <
EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>,
CType_,
Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_,
TileElemWiseMuls_,
TileCopy_
> {
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(
std::is_same_v<ElementC, half> && (std::is_same_v<ElementD, float> || std::is_same_v<ElementD, int8_t>),
"The element type template parameters of BlockEpilogue are wrong"
);
static_assert(
std::is_same_v<LayoutC, layout::RowMajor> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong"
);
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm<ArchTag, Gemm::GemmType<ElementPerTokenScale, LayoutPerTokenScale>>;
struct Params {
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_
) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_), layoutD(layoutD_) {}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
constexpr uint32_t blockN = 4096;
constexpr uint32_t ChunkTileLen = blockN / 2;
constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += blockN * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += blockN * sizeof(ElementD);
ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += blockN * sizeof(float);
ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += ChunkTileLen * sizeof(float);
ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += ChunkTileLen * sizeof(float);
ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += HalfChunkTileLen * sizeof(float);
ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<int32_t>();
ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<half>();
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
}
// Drain the flag state primed in the constructor so the hardware event
// queues are left balanced when the epilogue is torn down.
CATLASS_DEVICE
void Finalize()
{
    uint32_t stage = 0;
    while (stage < UB_STAGES) {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[stage]);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[stage]);
        ++stage;
    }
}
// Intentionally empty: event-flag cleanup is done explicitly via Finalize(),
// not in the destructor.
CATLASS_DEVICE
~BlockEpilogue()
{
}
// Swap in a new parameter set (output pointers/layouts), e.g. when this
// epilogue instance is reused for the next group.
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
    this->params = params_;
}
// Each tile is 1 x blockN; each block is all tokens of one expert = [group[i], N].
//
// Per-token-dequant + SwiGLU + per-token re-quant epilogue. For each row of C:
//   1. load the raw GEMM row into UB and cast to fp32,
//   2. multiply by the per-token dequant scale from gmPerTokenScale1,
//   3. SwiGLU: silu(first half) * second half -> ChunkTileLen values,
//   4. dynamic symmetric int8 quantization (scale = max|x| / 127),
//      writing the quantized row to gmD and the new scale to gmPerTokenScale2.
// Only the last `epilogueCoreNum` sub-blocks participate; the leading
// `subblockNum - epilogueCoreNum` cores return immediately (they are the
// data-movement cores of the fused kernel).
CATLASS_DEVICE
void operator() (
    AscendC::GlobalTensor<ElementC> const &gmC,
    MatrixCoord const &shapeC,
    AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale1,
    AscendC::GlobalTensor<ElementD> const &gmD,
    AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale2,
    uint32_t epilogueCoreNum = 40,
    Callback &&callback = Callback{}
)
{
    callback();
    uint32_t blockM = shapeC.row();
    uint32_t blockN = shapeC.column();
    uint32_t tileLoops = blockM;
    // Flatten (cube block, vector sub-block) into a single sub-block index.
    uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();
    uint32_t subblockNum = get_block_num() * 2;
    uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;
    if (subblockIdx < moveDataCoreNum) {
        return;
    }
    // Distribute the blockM rows over the epilogue cores; the first
    // `remainderData` cores take one extra row.
    uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;
    uint32_t perCoreData = blockM / epilogueCoreNum;
    uint32_t remainderData = blockM % epilogueCoreNum;
    uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
    uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);
    // NOTE(review): alignedPerCoreData is computed but never used below.
    uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1);
    uint32_t ChunkTileLen = blockN / 2;
    uint32_t HalfChunkTileLen = ChunkTileLen / 2;
    for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {
        auto gmTileC = gmC[loopIdx * blockN];
        auto &ubC = ubCList[ubListId];
        auto &ubD = ubDList[ubListId];
        auto &ubCFp32 = ubCFp32List[ubListId];
        auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];
        auto &ubAbs = ubCFp32ChunkNAbsList[ubListId];
        // auto &ubMax = ubCFp32ChunkNMaxList[ubListId];
        auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId];
        auto &ubOutputTmp = ubAbs;                 // quant output reuses the abs buffer
        auto &sharedUbTmpBuffer = ubReduceMax;     // ReduceMax scratch aliases its own output
        auto &ubQuantS32 = ubQuantS32List[ubListId];
        auto &ubQuantF16 = ubQuantF16List[ubListId];
        auto gmTileD = gmD[loopIdx * ChunkTileLen];
        LayoutC layoutUbC{1, blockN};
        // Move C from the GM workspace into UB.
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
        copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
        // Cast C to FP32 on UB.
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
        AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
        // Fetch the per-token scale: row loopIdx of gmPerTokenScale1 (scalar read).
        ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);
        AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
        AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
        // Multiply the FP32 C by the per-token scale (dequantization).
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
        AscendC::PipeBarrier<PIPE_V>();
        // SwiGLU: silu(x) = x / (1 + exp(-x)) on the first half, then
        // multiplied elementwise by the second half.
        AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        // TODO: check whether the division affects data beyond ChunkTileLen.
        AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
        // Dynamic per-token int8 quantization: scale = max|x| / 127.
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        // `false`: only the max value is produced, not its index.
        AscendC::ReduceMax<float>(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false);
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
        AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
        // TODO: compare the efficiency of the two quantization approaches.
        ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
        AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
        // Stash the dequant scale locally; flushed to GM once after the loop.
        auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx;
        ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f);
        AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
        AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        // fp32 -> int32 -> half -> int8 cast chain (hardware-supported path).
        AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::SetDeqScale(static_cast<half>(1.0));
        AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
        AscendC::PipeBarrier<PIPE_V>();
        // NOTE(review): the D-buffer waits/sets below mix eventUbDVMTE3List and
        // eventUbDMTE3VList; both lists hold identical ids (each counts 0..UB_STAGES-1)
        // so this is currently harmless, but the naming looks swapped — verify.
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
        AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
        // AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);
        LayoutD layoutUbD{1, ChunkTileLen};
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
        copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
        ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
    }
    // Flush this core's accumulated per-token scales to GM in one shot.
    if(tasksForIdx > 0){
        LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx};
        AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
        copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2);
    }
}
private:
    Params params;  // output pointers and layouts
    // Per-stage UB tensors, allocated in the constructor.
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];    // raw GEMM output row
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];    // quantized output row
    // Hardware event ids guarding each stage's C/D buffers.
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];
    uint32_t ubListId{0};  // round-robin cursor over the UB stages
    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];          // fp32 copy of C
    AscendC::LocalTensor<float> ubCFp32ChunkNList[UB_STAGES];    // SwiGLU result (half width)
    AscendC::LocalTensor<float> ubCFp32ChunkNAbsList[UB_STAGES]; // |x| scratch; aliased by quant views
    AscendC::LocalTensor<float> ubCFp32ChunkNMaxList[UB_STAGES]; // ReduceMax output/scratch
    AscendC::LocalTensor<int32_t> ubQuantS32List[UB_STAGES];     // int32 view of the abs buffer
    AscendC::LocalTensor<half> ubQuantF16List[UB_STAGES];        // half view of the abs buffer
    AscendC::LocalTensor<float> ubPerTokenScaleOutput;  // per-core dequant scale staging
    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
    CopyUbToGmDequantScale copyUbToGmDequantScale;
};
} // namespace Catlass::Epilogue::Block
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP

View File

@@ -0,0 +1,502 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/helper.hpp"
#include "dispatch_policy_custom.hpp"
namespace Catlass::Gemm::Block {
// Set-then-wait on a single hardware event: serializes the two pipelines
// associated with `event` at this point of the instruction stream.
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID) {
    AscendC::SetFlag<event>(eventID);
    AscendC::WaitFlag<event>(eventID);
}
// BlockMmad specialization for the MmadAtlasA2PreloadAsyncFixpipe dispatch
// policy: an asynchronous, software-pipelined matmul block. operator() only
// *emits* GM->L1 loads and queues the compute descriptors; the actual
// L1->L0->MMAD->fixpipe work for a tile is executed PRELOAD_STAGES iterations
// later (or in SynchronizeBlock()). For int8 inputs the L0C->GM fixpipe copy
// applies a per-channel dequant using a uint64 scale vector staged in L1.
template <
    uint32_t PRELOAD_STAGES_,
    uint32_t L1_STAGES_,
    uint32_t L0A_STAGES_,
    uint32_t L0B_STAGES_,
    uint32_t L0C_STAGES_,
    bool ENABLE_UNIT_FLAG_,
    bool ENABLE_SHUFFLE_K_,
    class L1TileShape_,
    class L0TileShape_,
    class AType_,
    class BType_,
    class CType_,
    class BiasType_,
    class TileCopy_,
    class TileMmad_
>
struct BlockMmad <
    MmadAtlasA2PreloadAsyncFixpipe<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    >,
    L1TileShape_,
    L0TileShape_,
    AType_,
    BType_,
    CType_,
    BiasType_,
    TileCopy_,
    TileMmad_
> {
public:
    // Type Aliases
    using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    >;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;
    using ElementA = typename AType_::Element;
    using LayoutA = typename AType_::Layout;
    using ElementB = typename BType_::Element;
    using LayoutB = typename BType_::Layout;
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using TileMmad = TileMmad_;
    using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
    using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
    // Copies the per-channel uint64 scale vector into L1 (int8 path only).
    using CopyGmToL1S = Gemm::Tile::CopyGmToL1<ArchTag, Gemm::GemmType<uint64_t, layout::VectorLayout>>;
    using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
    using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
    using ElementAccumulator =
        typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
    // int8 inputs use the per-channel dequantizing fixpipe copy; otherwise the
    // policy's default L0C->GM copy is used unchanged.
    using CopyL0CToGm = typename std::conditional<
        std::is_same_v<ElementA, int8_t>,
        Gemm::Tile::CopyL0CToGm<ArchTag, ElementAccumulator, CType_, Gemm::Tile::ScaleGranularity::PER_CHANNEL>,
        typename TileCopy_::CopyL0CToGm
    >::type;
    using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
    using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
    using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
    using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
    using LayoutCInL0 = layout::zN;
    using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
    using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;
    static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
    static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES;
    static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
    static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
    static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;
    static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
    static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;
    // L1 tile size
    static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
    static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t);
    // L0 tile size
    static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
    static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);
    // Check LayoutC
    static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");
    // Check L1TileShape (int8 additionally reserves room for the scale vector)
    static_assert(
        (std::is_same_v<ElementA, int8_t>
            ? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE
            : (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE),
        "L1TileShape exceeding the L1 space for the given data type"
    );
    // Check L0TileShape
    static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!");
    static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!");
    static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!");
    static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
        "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet");
    static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(
        L1TileShape::M, L1TileShape::K);
    static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(
        L1TileShape::K, L1TileShape::N);

    // Allocate L1/L0 buffers and prime all producer-side event flags.
    CATLASS_DEVICE
    BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
    {
        syncGroupIdx = 0;
        InitL1(resource, l1BufAddrStart);
        InitL0A(resource);
        InitL0B(resource);
        InitL0C(resource);
    }

    // Flush any still-queued compute, then drain every primed event flag so
    // the hardware queues are balanced on teardown.
    CATLASS_DEVICE
    ~BlockMmad()
    {
        SynchronizeBlock();
        for (uint32_t i = 0; i < L1_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
        if constexpr (std::is_same_v<ElementA, int8_t>) {
            AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
        }
    }

    // Emit the GM->L1 loads for one output block and enqueue the matching
    // compute descriptors. Compute for a tile is deferred until the preload
    // ring is full, which overlaps MTE2 loads with MMAD of earlier tiles.
    CATLASS_DEVICE
    void operator()(
        AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
        AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
        AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
        AscendC::GlobalTensor<uint64_t> const &gmBlockS, layout::VectorLayout const &layoutScale,
        GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0
    )
    {
        uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
        uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
        uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());
        uint32_t startTileIdx = 0;
        if constexpr (ENABLE_SHUFFLE_K) {
            // Stagger the k start per core to spread GM read traffic.
            startTileIdx = AscendC::GetBlockIdx() % kTileCount;
        }
        for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
            // Wrap-around k index (avoids a modulo).
            uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ?
                (startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount);
            uint32_t kActual = (kTileIdx < kTileCount - 1) ?
                L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);
            // Emission load instruction from GM to L1
            MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
            MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
            auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
            auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
            // Load first matrix A tile from GM to L1
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
            auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
            copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
            // Load first matrix B tile from GM to L1
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
            auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
            copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);
            // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile
            if (preloadCount == PRELOAD_STAGES) {
                L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            }
            // Store the current load status
            uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ?
                (l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
            auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
            l1TileMmadParams.l1ListId = l1ListId;
            l1TileMmadParams.mRound = mRound;
            l1TileMmadParams.nRound = nRound;
            l1TileMmadParams.kActual = kActual;
            l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
            l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
            l1TileMmadParams.flag = flag;
            // Output information is only needed on the last k iteration.
            if (kLoopIdx == kTileCount - 1) {
                l1TileMmadParams.gmBlockC = gmBlockC;
                l1TileMmadParams.gmBlockS = gmBlockS;
                l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
                l1TileMmadParams.layoutScale = layoutScale;
                l1TileMmadParams.syncLoopIdx = syncLoopIdx;
            }
            if (preloadCount < PRELOAD_STAGES) {
                ++preloadCount;
            } else {
                l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            }
            l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0;
        }
    }

    // Drain the preload ring: run the compute for every still-queued tile.
    CATLASS_DEVICE
    void SynchronizeBlock()
    {
        while (preloadCount > 0) {
            L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            --preloadCount;
        }
    }

    // Raise cross-core (cube -> vector) flags for every sync group up to and
    // including `target`. With the default target of -1 this is a no-op.
    // NOTE(review): the "/ 8" grouping factor (8 loops per flag id) is assumed
    // to match the consumer side — confirm against the vector kernel.
    CATLASS_DEVICE
    void Finalize(int32_t target, int32_t flag = 0)
    {
        for(;syncGroupIdx <= target; syncGroupIdx++) {
            int32_t flagId = syncGroupIdx / 8 + flag;
            AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
        }
    }

private:
    // Deferred-compute descriptor: everything L1TileMmad needs to process one
    // L1 tile that was loaded by an earlier operator() iteration.
    struct L1TileMmadParams {
        uint32_t l1ListId;
        uint32_t mRound;
        uint32_t nRound;
        uint32_t kActual;
        bool isKLoopFirst;
        bool isKLoopLast;
        AscendC::GlobalTensor<ElementC> gmBlockC;
        AscendC::GlobalTensor<uint64_t> gmBlockS;
        LayoutC layoutCInGm;
        layout::VectorLayout layoutScale;
        int32_t syncLoopIdx;
        int32_t flag;
        CATLASS_DEVICE
        L1TileMmadParams() = default;
    };

    // Slice the L1 buffer into A/B staging tiles (plus the scale vector for
    // int8) and pre-set the MTE1->MTE2 flags.
    CATLASS_DEVICE
    void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
    {
        uint32_t l1AOffset = l1BufAddrStart;
        uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES;
        for (uint32_t i = 0; i < L1_STAGES; ++i) {
            l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
            l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
            l1AEventList[i] = i;
            l1BEventList[i] = i + L1_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
        if constexpr (std::is_same_v<ElementA, int8_t>) {
            uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
            l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
            AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
        }
    }

    // Slice L0A into stage buffers and pre-set M->MTE1 flags.
    CATLASS_DEVICE
    void InitL0A(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
            l0AEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
    }

    // Slice L0B into stage buffers; event ids follow the L0A ids.
    CATLASS_DEVICE
    void InitL0B(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
            l0BEventList[i] = i + L0A_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
    }

    // Slice L0C into stage buffers and pre-set FIX->M flags.
    CATLASS_DEVICE
    void InitL0C(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
            l0CEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
    }

    // Core compute for one queued L1 tile: L1->L0 copies, MMAD over the
    // m/k/n sub-tiles, and — on the last k tile — the fixpipe L0C->GM copy
    // (per-channel dequant for int8) followed by the cross-core notification.
    CATLASS_DEVICE
    void L1TileMmad(L1TileMmadParams const &params)
    {
        uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
        uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
        uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
        auto &l1ATensor = l1ATensorList[params.l1ListId];
        auto &l1BTensor = l1BTensorList[params.l1ListId];
        auto &l0CTensor = l0CTensorList[l0CListId];
        LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));
        if constexpr (!ENABLE_UNIT_FLAG) {
            if (params.isKLoopFirst) {
                AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            }
        }
        for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
            uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ?
                L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);
            for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
                uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ?
                    L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);
                auto &l0ATile = l0ATensorList[l0AListId];
                auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
                auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
                auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];
                AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                // Only the first sub-tile waits for the GM->L1 load; the rest
                // of the tile reuses the same L1 buffer.
                if ((mPartIdx == 0) && (kPartIdx == 0)) {
                    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1ListId]);
                }
                copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
                // Releasing L1A after the last sub-tile lets the next GM load begin.
                if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1ListId]);
                }
                for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
                    uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ?
                        L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);
                    auto &l0BTile = l0BTensorList[l0BListId];
                    auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
                    auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
                    auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];
                    AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    if ((kPartIdx == 0) && (nPartIdx == 0)) {
                        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1ListId]);
                    }
                    copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
                    if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
                        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1ListId]);
                    }
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
                    auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
                    auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];
                    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
                    // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
                    bool initC = (params.isKLoopFirst && (kPartIdx == 0));
                    // If the unit flag is enabled, the unit flag is set according to the calculation progress
                    uint8_t unitFlag = 0b00;
                    if constexpr (ENABLE_UNIT_FLAG) {
                        if (params.isKLoopLast &&
                            (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
                            unitFlag = 0b11;
                        } else {
                            unitFlag = 0b10;
                        }
                    }
                    tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
                    AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
                }
                AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
            }
        }
        // After the last k tile the accumulated block is written back to GM.
        if (params.isKLoopLast) {
            auto layoutCInGm = params.layoutCInGm;
            if constexpr (std::is_same_v<ElementA, int8_t>) {
                // Stage the per-channel dequant scale vector into L1 for fixpipe.
                auto layoutScale = params.layoutScale;
                auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1)));
                AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
                copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS);
                AscendC::SetFlag<AscendC::HardEvent::MTE2_FIX>(0);
                AscendC::WaitFlag<AscendC::HardEvent::MTE2_FIX>(0);
            }
            if constexpr (!ENABLE_UNIT_FLAG) {
                AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0);
                } else {
                    copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
                }
                AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            } else {
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11);
                } else {
                    copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
                }
            }
            l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;
            if constexpr (std::is_same_v<ElementA, int8_t>) {
                AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
            }
            // Notify the vector cores that this loop's output is ready.
            Finalize(params.syncLoopIdx, params.flag);
        }
    }

    AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
    AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
    AscendC::LocalTensor<uint64_t> l1STensor;  // per-channel dequant scales (int8 path)
    int32_t syncGroupIdx;                      // next cross-core sync group to signal
    int32_t l1AEventList[L1_STAGES];
    int32_t l1BEventList[L1_STAGES];
    uint32_t l1ListId{0};
    AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
    int32_t l0AEventList[L0A_STAGES];
    uint32_t l0AListId{0};
    AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
    int32_t l0BEventList[L0B_STAGES];
    uint32_t l0BListId{0};
    AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
    int32_t l0CEventList[L0C_STAGES_];
    uint32_t l0CListId{0};
    // Ring buffer of deferred compute descriptors.
    L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
    uint32_t l1TileMmadParamsId{0};
    uint32_t preloadCount{0};
    TileMmad tileMmad;
    CopyGmToL1A copyGmToL1A;
    CopyGmToL1B copyGmToL1B;
    CopyGmToL1S copyGmToL1S;
    CopyL1ToL0A copyL1ToL0A;
    CopyL1ToL0B copyL1ToL0B;
    CopyL0CToGm copyL0CToGm;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP

View File

@@ -0,0 +1,6 @@
#ifndef CONST_ARGS_HPP
#define CONST_ARGS_HPP
// One mebibyte, in bytes.
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
// int32 slots reserved per synchronization flag (16 * 4 B = 64 B,
// presumably one cache line per flag to avoid false sharing — TODO confirm).
constexpr static int32_t NUMS_PER_FLAG = 16;
#endif

View File

@@ -0,0 +1,40 @@
#ifndef COPY_GM_TO_L1_CUSTOM_HPP
#define COPY_GM_TO_L1_CUSTOM_HPP
namespace Catlass::Gemm::Tile {
/// Partial specialization: contiguous 1-D (VectorLayout) copy from GM to L1,
/// used to stage per-channel scale vectors for the quantized fixpipe path.
template <
    class ArchTag,
    class Element
>
struct CopyGmToL1<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
    using LayoutDst = layout::VectorLayout;
    using LayoutSrc = layout::VectorLayout;
    // Elements per 32-byte C0 block (e.g. uint64: 32 / 8 = 4).
    static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
    // Methods
    CATLASS_DEVICE
    CopyGmToL1() {};
    // Copy layoutSrc.shape(0) elements as a single contiguous burst.
    // The length is rounded up to whole C0 blocks, so the destination must
    // have room for the padding.
    CATLASS_DEVICE
    void operator()(
        AscendC::LocalTensor<Element> const &dstTensor,
        AscendC::GlobalTensor<Element> const &srcTensor,
        LayoutDst const &layoutDst, LayoutSrc const &layoutSrc)
    {
        uint32_t blockCount = 1;
        uint32_t blockLen = CeilDiv<ELE_NUM_PER_C0>(layoutSrc.shape(0));
        AscendC::DataCopyParams repeatParams;
        repeatParams.blockCount = blockCount;
        repeatParams.blockLen = blockLen;
        repeatParams.srcStride = 0;
        repeatParams.dstStride = 0;
        AscendC::DataCopy(dstTensor, srcTensor, repeatParams);
    }
};
}
#endif // COPY_GM_TO_L1_CUSTOM_HPP

View File

@@ -0,0 +1,47 @@
#ifndef COPY_L0C_TO_GM_CUSTOM_HPP
#define COPY_L0C_TO_GM_CUSTOM_HPP
namespace Catlass::Gemm::Tile {
// L0C -> GM copy with per-channel dequantization: the fixpipe consumes a
// uint64 scale vector already staged in L1 (cbufWorkspace) and converts the
// zN-laid-out accumulator to a RowMajor destination tile.
template <
    class ElementAccumulator_,
    class ElementDst_,
    bool ReluEnable_
>
struct CopyL0CToGm<Catlass::Arch::AtlasA2,
    ElementAccumulator_,
    Gemm::GemmType<ElementDst_, layout::RowMajor>,
    ScaleGranularity::PER_CHANNEL,
    ReluEnable_>
{
    using ArchTag = Catlass::Arch::AtlasA2;
    using ElementDst = ElementDst_;
    using ElementSrc = ElementAccumulator_;
    using LayoutSrc = Catlass::layout::zN;
    using LayoutDst = Catlass::layout::RowMajor;
    // Pre-quantization mode selected from the src/dst element-type pair.
    static constexpr auto quantPre = CopyL0CToGmQuantMode<ArchTag, ElementSrc, ElementDst,
        ScaleGranularity::PER_CHANNEL>::VALUE;
    static constexpr auto reluEn = ReluEnable_;
    CATLASS_DEVICE
    void operator()(AscendC::GlobalTensor<ElementDst> const &dst, AscendC::LocalTensor<ElementSrc> const &src, AscendC::LocalTensor<uint64_t> cbufWorkspace,
        LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0)
    {
        AscendC::FixpipeParamsV220 intriParams;
        // Fixpipe layout information
        intriParams.nSize = dstLayout.shape(1);
        intriParams.mSize = dstLayout.shape(0);
        // zN source: stride between n-blocks expressed in rows of the block.
        intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0);
        intriParams.dstStride = dstLayout.stride(0);
        // Fixpipe auxiliary arguments
        intriParams.quantPre = quantPre;
        intriParams.reluEn = reluEn;
        intriParams.unitFlag = unitFlag;
        // Call AscendC Fixpipe
        AscendC::Fixpipe<ElementDst, ElementSrc, AscendC::CFG_ROW_MAJOR>(dst, src, cbufWorkspace, intriParams);
    }
};
}
#endif // COPY_L0C_TO_GM_CUSTOM_HPP

View File

@@ -0,0 +1,47 @@
#ifndef DISPATH_POLICY_CUSTOM_HPP
#define DISPATH_POLICY_CUSTOM_HPP
namespace Catlass::Gemm {
// Dispatch tag: 2-stage MMAD on AtlasA2 whose output path goes through the
// quantizing fixpipe. Carries configuration only; the matching BlockMmad
// specialization provides the implementation.
template <bool ENABLE_UNIT_FLAG_ = false, bool ENABLE_SHUFFLE_K_ = false>
struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 {
    static constexpr uint32_t STAGES = 2;
    static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;
    static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;
};
// Dispatch tag: same staging parameters as MmadAtlasA2PreloadAsync, but a
// distinct type so the fixpipe-output BlockMmad specialization is selected.
template <uint32_t PRELOAD_STAGES_, uint32_t L1_STAGES_, uint32_t L0A_STAGES_, uint32_t L0B_STAGES_,
    uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncFixpipe :
    public MmadAtlasA2PreloadAsync<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    > {
};
}
namespace Catlass::Epilogue {
// Epilogue dispatch tags for AtlasA2. Each carries only the number of UB
// pipeline stages; the matching BlockEpilogue specialization implements the
// named behavior.
// No dequant/quant: plain copy-out epilogue.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2UnQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
// Per-token dequantize, then re-quantize.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
// Per-token dequantize, SwiGLU activation, then re-quantize.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
}
#endif // DISPATH_POLICY_CUSTOM_HPP

View File

@@ -0,0 +1,178 @@
#ifndef SYNC_UTIL_HPP
#define SYNC_UTIL_HPP
#include "kernel_operator.h"
#include "const_args.hpp"
#include "moe_distribute_base.h"
#ifndef HCCL_COMM
#include "shmem_api.h"
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
// Scalar store to global memory through a __gm__ pointer.
template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
    *addr = val;  // addr already carries the __gm__ qualifier; no cast needed
}
// Scalar load from global memory through a __gm__ pointer.
template<typename T>
FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
    return *cache;  // cache already carries the __gm__ qualifier; no cast needed
}
// Clean-and-invalidate the cache line holding `addr` so that a following GM
// load observes remote writes and preceding local writes become visible to
// peer devices. The empty asm statements must stay: they fence the dcci from
// compiler reordering/elimination.
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
    using namespace AscendC;
    GlobalTensor<uint8_t> global;
    global.SetGlobalBuffer(addr);
    // Important: add hint to avoid dcci being optimized by compiler
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<uint8_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(global);
    __asm__ __volatile__("");
}
// Spin on a GM signal until it equals cmp_val, invalidating the cache line on
// every poll. Also accepts cmp_val + 1: a fast peer may already have entered
// the next barrier round, and waiting for exact equality then would deadlock.
// Returns the observed value (cmp_val or cmp_val + 1).
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    do {
        gm_dcci((__gm__ uint8_t *)sig_addr);
        if (*sig_addr == cmp_val) {
            return *sig_addr;
        }
        // in case when peer pe enters next barrier
        if (*sig_addr == cmp_val + 1) {
            return *sig_addr;
        }
    } while (true);
    // never reached
    return -1;
}
// Maximum number of ranks whose window pointers are cached.
constexpr int32_t MAX_RANK_SIZE = 32;
// Unified accessor for peer-visible shared memory. Two backends:
//   - HCCL_COMM: window pointers are read from the AICPU HCCL context;
//   - otherwise: the OpenSHMEM-style shmem API provides the symmetric heap.
// The last MB_SIZE bytes of the window are reserved for barrier flags.
class HcclShmem {
public:
#ifdef HCCL_COMM // the HCCL backend needs the HCCL context to be initialized
    __gm__ HcclOpResParamCustom *WinContext_{nullptr};
    Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
    GM_ADDR m_ptrArray[MAX_RANK_SIZE];  // per-rank window base addresses
    size_t m_segmentSize;               // bytes per rank window
    int32_t m_rank;                     // this device's rank id
    int32_t m_rankSize;                 // number of ranks in the group
    // Resolve every rank's window base from the HCCL op resource context.
    FORCE_INLINE_AICORE
    HcclShmem(){
        auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
        WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
        m_rank = WinContext_->localUsrRankId;
        m_rankSize = WinContext_->rankSize;
        m_segmentSize = WinContext_->winSize;
        for (int i = 0; i < m_rankSize; i++) {
            m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
                ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
        }
    }
    FORCE_INLINE_AICORE
    size_t SegmentSize() const {
        return m_segmentSize;
    }
    FORCE_INLINE_AICORE
    int32_t RankSize() const {
        return m_rankSize;
    }
#endif
    // No arguments: base address of the local peer-memory window.
    FORCE_INLINE_AICORE
    GM_ADDR operator() () const {
#ifdef HCCL_COMM
        return m_ptrArray[m_rank];
#else
        return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
#endif
    }
    // With an index: base address of the remote peer-memory window of `index`.
    FORCE_INLINE_AICORE
    GM_ADDR operator() (int32_t index) const {
#ifdef HCCL_COMM
        return m_ptrArray[index];
#else
        return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
#endif
    }
    // Byte offset within rank `rankId`'s window; nullptr when out of range
    // (HCCL backend only performs the range check).
    FORCE_INLINE_AICORE
    GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
        if (offset < 0 || offset >= m_segmentSize) {
            return nullptr;
        }
        if (rankId < 0 || rankId >= m_rankSize) {
            return nullptr;
        }
        return m_ptrArray[rankId] + offset;
#else
        return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
#endif
    }
    FORCE_INLINE_AICORE
    ~HcclShmem() {
    }
    // Full cross-rank barrier. Each vector core posts a monotonically
    // increasing round counter into every peer's flag area (one 16-int slot
    // per writer rank), then waits for the matching counter from that peer.
    // NOTE(review): the literal 16 matches NUMS_PER_FLAG and 2048 is the
    // sync_base slot within the reserved MB — confirm these stay in sync with
    // the flag-area layout used elsewhere.
    FORCE_INLINE_AICORE
    void CrossRankSync() {
        // Flag area occupies the last MB_SIZE bytes of the window,
        // addressed in int32 units.
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        __gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset;
        __gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048;
        int count = gm_load(sync_base) + 1;  // next barrier round
        int vec_id = AscendC::GetBlockIdx();
        int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation();
        // Ranks are distributed across the vector cores.
        for(int i = vec_id; i < m_rankSize; i += vec_size) {
            __gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16;
            gm_store(sync_remote, count);
            gm_dcci((__gm__ uint8_t*)sync_remote);
            auto sync_check = sync_counter + i * 16;
            gm_signal_wait_until_eq_for_barrier(sync_check, count);
        }
        AscendC::SyncAll<true>();
        gm_store(sync_base, count);  // publish the completed round locally
    }
    // Address of the local round counter used by CrossRankSync.
    FORCE_INLINE_AICORE
    __gm__ int32_t* SyncBaseAddr() {
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
    }
};
#endif

View File

@@ -0,0 +1,20 @@
#ifndef LAYOUT_3D_HPP
#define LAYOUT_3D_HPP
#include "kernel_operator.h"
#include "catlass/catlass.hpp"
class Layout3D {
int64_t strides[2];
public:
CATLASS_DEVICE
Layout3D() {}
CATLASS_DEVICE
Layout3D(int64_t stride0, int64_t stride1) {
strides[0] = stride0;
strides[1] = stride1;
}
CATLASS_DEVICE
int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) {
return dim0 * strides[0] + dim1 * strides[1] + dim2;
}
};
#endif // LAYOUT_3D_HPP

View File

@@ -0,0 +1,25 @@
#ifndef SELECT_HELPER_HPP
#define SELECT_HELPER_HPP
#include "catlass/layout/layout.hpp"
using namespace AscendC;
using namespace Catlass;
// Primary LayoutB factory: used for layout types constructible directly from
// the (k, n) extents. ElementType is unused here but kept so the zN
// specialization below can depend on it.
template <typename Layout, typename ElementType, typename = void>
struct LayoutBInitializer {
    CATLASS_DEVICE
    // k: reduction-dimension extent of B; n: output-dimension extent.
    static Layout create(uint32_t k, uint32_t n) {
        return Layout{k, n};
    }
};
// Specialization selected (via SFINAE) when Layout is Catlass layout::zN:
// such layouts must be built through their MakeLayout<ElementType> factory
// rather than from the raw extents.
template <typename Layout, typename ElementType>
struct LayoutBInitializer<Layout, ElementType,
    std::enable_if_t<std::is_same_v<Layout, layout::zN>>
> {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout::template MakeLayout<ElementType>(k, n);
    }
};
#endif // SELECT_HELPER_HPP

1
csrc/third_party/catlass vendored Submodule

View File

@@ -37,59 +37,60 @@
namespace vllm_ascend {
const int64_t INT4_NUMS_IN_INT32 = 8;
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping, aclrtStream stream) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
aclrtMemcpyKind memcpy_type;
const torch::Tensor& block_mapping, aclrtStream stream)
{
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
aclrtMemcpyKind memcpy_type;
if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same npu");
memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
} else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
} else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
} else {
TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
}
if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same npu");
memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
} else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
} else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
} else {
TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
}
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const int64_t num_blocks = block_mapping.size(0);
const int64_t max_src_block = src.size(0);
const int64_t max_dst_block = dst.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block,
"src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block,
"dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
const int64_t num_blocks = block_mapping.size(0);
const int64_t max_src_block = src.size(0);
const int64_t max_dst_block = dst.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block,
"src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block,
"dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
src_ptr + src_offset, block_size_in_bytes,
memcpy_type, stream);
}
aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
src_ptr + src_offset, block_size_in_bytes,
memcpy_type, stream);
}
}
void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
{
const c10_npu::OptionalNPUGuard npuGuard(
const c10_npu::OptionalNPUGuard npuGuard(
(!x.device().is_cpu()) ? x.device() : y.device()
);
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
swap_blocks_impl(x, y, z, stream);
return;
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
swap_blocks_impl(x, y, z, stream);
return;
}
AscendType get_dtype_from_torch(at::ScalarType scalarType)
@@ -617,7 +618,33 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
});
cmd.Run();
return;
}
// Fused MoE entry point: dispatches tokens to experts, runs the quantized
// two-matmul FFN and combines the weighted expert outputs inside a single
// aclnnDispatchFFNCombine kernel launch.
//
// x:               input hidden states (bf16 in the e2e test)
// weight1/weight2: per-expert int8 weights of the two FFN matmuls
// expert_idx:      top-k expert ids per token
// scale1/scale2:   dequant scales, bit-cast to int64 (see
//                  scale_from_float_to_int64 in w8a8_dynamic.py)
// probs:           top-k routing weights (float32)
// group:           HCCL communicator name of the expert-parallel group
// max_output_size: upper bound for the op's internal output buffering
// out:             pre-allocated output tensor; written in place and returned
at::Tensor& dispatch_ffn_combine(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out
) {
    // NOTE(review): aclnn takes a mutable char*; this assumes `group` is
    // NUL-terminated (true when it originates from a bound Python str), since
    // c10::string_view itself does not guarantee termination -- confirm.
    char *group_ep_ptr = const_cast<char *>(group.data());
    EXEC_NPU_CMD(aclnnDispatchFFNCombine,
                 x,
                 weight1,
                 weight2,
                 expert_idx,
                 scale1,
                 scale2,
                 probs,
                 group_ep_ptr,
                 max_output_size,
                 out);
    return out;
}
at::Tensor npu_lightning_indexer(
@@ -810,4 +837,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" int sparse_mode=3) -> Tensor"
);
ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
ops.def(
"dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx,"
" Tensor scale1, Tensor scale2, Tensor probs, str group,"
" int max_output_size, Tensor! out) -> Tensor"
);
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
}

View File

@@ -156,7 +156,21 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
c10::optional<c10::string_view> quant_mode)
{
return;
}
// Meta (shape-inference) implementation of dispatch_ffn_combine: the real op
// writes into the caller-provided `out`, so no shape computation is needed
// here -- simply return `out` unchanged so tracing/compilation can proceed.
at::Tensor& dispatch_ffn_combine_meta(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out
) {
    return out;
}
at::Tensor npu_lightning_indexer_meta(
@@ -244,5 +258,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta);
// Sparse flash attention
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
// MoE dispatch-ffn-combine
ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta);
}
}

View File

@@ -0,0 +1,168 @@
import random
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch.distributed.distributed_c10d import _get_default_group
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
class TestDisptachFFNCombine:
    """Per-process harness exercising the `dispatch_ffn_combine` custom op.

    One instance lives in each worker process: it initializes an HCCL process
    group, resolves the communicator name, and launches the fused kernel once
    on small random inputs. (NOTE(review): class name misspells "Dispatch";
    renaming would touch `worker` below, so it is left as-is here.)
    """

    def __init__(self, rank, world_size, port):
        # Rank / world size of this worker and the TCP rendezvous endpoint.
        self.rank = rank
        self.world_size = world_size
        self.master_ip = "127.0.0.1"
        self.port = port

    def get_hcomm(self, comm_group):
        """Return the HCCL communicator name for `comm_group`.

        NOTE(review): `torch.__version__ > "2.0.1"` is a lexicographic string
        comparison; it is fine for current releases but would misorder a
        hypothetical major version >= 10.
        """
        hcomm_info = None
        if torch.__version__ > "2.0.1":
            # Newer torch exposes the comm name through the backend object.
            hcomm_info = comm_group._get_backend(
                torch.device("npu")).get_hccl_comm_name(self.rank)
        else:
            hcomm_info = comm_group.get_hccl_comm_name(self.rank)
        return hcomm_info

    def setup_ep_tp(
        self,
        rank,
        tp_size,
        ep_size,
        backend_type,
        ep_ranks_list=None,
        tp_ranks_list=None,
    ):
        """Create expert-parallel and tensor-parallel subgroups.

        Every rank must take part in all `dist.new_group` calls, so both loops
        run on every rank and only the groups containing `rank` are kept.

        NOTE(review): if `rank` belongs to no EP (or TP) group, `ep_group_tmp`
        (`tp_group_tmp`) stays unbound and the return raises
        UnboundLocalError -- callers must pass consistent sizes/rank lists.
        """
        for i in range(tp_size):
            if ep_ranks_list:
                ep_ranks = ep_ranks_list[i]
            else:
                ep_ranks = [x + ep_size * i for x in range(ep_size)]
            ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
            if rank in ep_ranks:
                ep_group_tmp = ep_group
        for i in range(ep_size):
            if tp_ranks_list:
                tp_ranks = tp_ranks_list[i]
            else:
                tp_ranks = [x * ep_size + i for x in range(tp_size)]
            tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
            if rank in tp_ranks:
                tp_group_tmp = tp_group
        return ep_group_tmp, tp_group_tmp

    def generate_hcom(self):
        """Initialize the default HCCL process group and cache its comm name."""
        torch_npu.npu.set_device(self.rank)
        dist.init_process_group(
            backend="hccl",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"tcp://127.0.0.1:{self.port}",
        )
        # ep_size is hard-coded to 0, so the EP/TP branch below is dead and
        # the default process group's communicator is always used.
        ep_size = 0
        tp_size = self.world_size
        hcomm_info_dist = {
            "default_pg_info": None,
            "ep_hcomm_info": None,
            "group_ep": None,
            "tp_hcomm_info": None,
            "group_tp": None,
        }
        if ep_size and tp_size:
            group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size,
                                                  "hccl", None, None)
            hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep)
            hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp)
            hcomm_info_dist["group_ep"] = group_ep
            hcomm_info_dist["group_tp"] = group_tp
        else:
            if dist.is_available():
                default_pg = _get_default_group()
                hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg)
                hcomm_info = hcomm_info_dist["default_pg_info"]
        # NOTE(review): `hcomm_info` is only bound on the else-path above;
        # safe today because ep_size == 0, but fragile if that changes.
        self.hcomm_info = hcomm_info

    def run_npu_out(self) -> bool:
        """Build small random inputs and invoke the fused kernel once.

        Returns True when the op completes without raising; outputs are not
        numerically validated here.
        """
        torch_npu.npu.set_device(self.rank)
        m = 2  # token-num 32
        k = 4  # hidden_size 7168
        n = 4  # mid-hidden-size 4096
        topk = 2
        e = 2  # expert-num-per-rank 16
        k2 = n // 2
        n2 = k
        torch_npu.npu.config.allow_internal_format = True
        x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
        weight1 = self.generate_random_tensor((e, k, n),
                                              dtype=torch.int8).npu()
        # Cast weights to internal format 29 (ACL fractal-NZ layout).
        weight1 = torch_npu.npu_format_cast(weight1, 29)
        weight2 = self.generate_random_tensor((e, k2, n2),
                                              dtype=torch.int8).npu()
        weight2 = torch_npu.npu_format_cast(weight2, 29)
        expert_idx = torch.randint(0,
                                   self.world_size * e, (m, topk),
                                   dtype=torch.int32).npu()
        # randint(0, 1) always yields 0: the int64 bit-cast scales are zeroed.
        scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
        scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
        probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
        out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
        torch.ops._C_ascend.dispatch_ffn_combine(
            x=x,
            weight1=weight1,
            weight2=weight2,
            expert_idx=expert_idx,
            scale1=scale1,
            scale2=scale2,
            probs=probs,
            group=self.hcomm_info,
            max_output_size=512,
            out=out,
        )
        return True

    def generate_random_tensor(self, size, dtype):
        """Return a random CPU tensor of `size`; distribution depends on dtype."""
        if dtype in [torch.float16, torch.bfloat16, torch.float32]:
            return torch.randn(size=size, dtype=dtype)
        elif dtype is torch.int8:
            return torch.randint(-16, 16, size=size, dtype=dtype)
        elif dtype is torch.int32:
            return torch.randint(-1024, 1024, size=size, dtype=dtype)
        else:
            raise ValueError(f"Invalid dtype: {dtype}")
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
    """Per-process entry point: build the harness, set up the HCCL
    communicator, run the kernel once and report the outcome via `q`."""
    harness = TestDisptachFFNCombine(rank, world_size, port)
    harness.generate_hcom()
    q.put(harness.run_npu_out())
@torch.inference_mode()
def test_dispatch_ffn_combine_kernel():
    """Spawn one process per rank, run dispatch_ffn_combine on each rank,
    and assert that every rank reported success."""
    world_size = 2
    mp.set_start_method("fork", force=True)
    result_queue = mp.SimpleQueue()
    listen_port = 29501 + random.randint(0, 10000)
    procs = [
        mp.Process(target=worker,
                   args=(rank, world_size, listen_port, result_queue))
        for rank in range(world_size)
    ]
    for proc in procs:
        proc.start()
    # Drain one result per rank before joining so no worker blocks on put().
    outcomes = [result_queue.get() for _ in range(world_size)]
    for proc in procs:
        proc.join()
    assert all(outcomes)

View File

@@ -52,6 +52,7 @@ class MoECommType(Enum):
ALLGATHER = 0
MC2 = 1
ALLTOALL = 2
FUSED_ALLTOALL = 3
@contextmanager

View File

@@ -520,7 +520,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_output

View File

@@ -44,6 +44,8 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
moe_config)
class MoECommMethod(ABC):
@@ -243,3 +245,69 @@ class AlltoAllCommImpl(MoECommMethod):
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
class FusedAlltoAllCommImpl(MoECommMethod):
    """MoE communication backend that routes dispatch -> expert FFN -> combine
    through the single fused `dispatch_ffn_combine` custom op.

    This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. The model's MoE weights are w8a8-dynamic quantized, so the fused
       quantized kernel applies.
    It uses all-to-all communication to exchange tokens between data parallel
    ranks before and after the MLP computation; it should outperform
    AllGatherCommImpl when DP size > 1.
    """

    def _get_token_dispatcher(self):
        # The dispatcher carries the HCCL communicator name
        # (moe_all_to_all_group_name) consumed by fused_experts below.
        return TokenDispatcherWithAll2AllV(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_prepare_finalize(self):
        return PrepareAndFinalizeWithAll2All(self.moe_config)

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            w1_scale_bias: torch.Tensor = None,
            w2_scale_bias: torch.Tensor = None,
            # For TorchAir graph
            is_torchair: bool = False,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: torch.Tensor = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False,
            mc2_mask: torch.Tensor = None,
            pertoken_scale: Optional[torch.Tensor] = None):
        """Run the whole MoE layer via the fused kernel.

        The signature mirrors the other MoECommMethod implementations; most
        parameters (activation, expert_map, biases, eplb knobs, ...) are
        accepted for interface compatibility but are not used by the fused
        path -- the kernel handles quantization, routing and combine
        internally and writes its result into `out`.
        """
        out = torch.empty_like(hidden_states)
        torch.ops._C_ascend.dispatch_ffn_combine(
            x=hidden_states,
            weight1=w1,
            weight2=w2,
            expert_idx=topk_ids,
            scale1=w1_scale,
            scale2=w2_scale,
            probs=topk_weights.to(torch.float32),
            group=self.token_dispatcher.moe_all_to_all_group_name,
            max_output_size=65536,
            out=out,
        )
        return out

View File

@@ -513,6 +513,11 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=self.ep_group)
backend = self.ep_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,

View File

@@ -249,8 +249,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
} or forward_context.sp_enabled:
if moe_comm_type in {
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
} or forward_context.sp_enabled:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)

View File

@@ -24,6 +24,7 @@ from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
@@ -232,13 +233,15 @@ class AscendW8A8DynamicFusedMoEMethod:
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
fused_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_ALLTOALL
return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
w1=w1[0] if fused_flag else w1,
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
w2=w2[0] if fused_flag else w2,
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
@@ -270,6 +273,12 @@ class AscendW8A8DynamicFusedMoEMethod:
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
layer.fused_w1_scale = scale_from_float_to_int64(
layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(
layer.w2_weight_scale.data)
if self.dynamic_eplb:
layer.w13_weight_list = [
weight.clone()
@@ -292,3 +301,11 @@ class AscendW8A8DynamicFusedMoEMethod:
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
torch.npu.empty_cache()
def scale_from_float_to_int64(scale):
    """Bit-cast a float scale tensor to int64 for the fused dispatch kernel.

    Each element is converted to float32, its raw 32-bit pattern reinterpreted
    as a (signed) int32, then sign-extended to int64. The result is flattened
    to 1-D -- matching the previous numpy ``frombuffer`` behaviour -- and
    moved back to the source tensor's device.

    Args:
        scale: float tensor of any shape/float dtype, on any device.
    Returns:
        1-D int64 tensor on ``scale.device`` holding the float32 bit patterns.
    """
    # torch.Tensor.view(dtype) reinterprets the underlying bytes in place,
    # replacing the former numpy tobytes/frombuffer round trip (and its
    # function-scope numpy import). contiguous() guarantees view() is legal
    # and matches the C-order flattening tobytes() produced.
    int32_bits = scale.cpu().to(torch.float32).contiguous().view(torch.int32)
    return int32_bits.reshape(-1).to(torch.int64).to(scale.device)

View File

@@ -911,6 +911,9 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
"dp": {
"hccl_buffer_size": calculate_dp_buffer_size()
},
"ep": {
"hccl_buffer_size": calculate_ep_buffer_size()
},
}
return hccl_config_map.get(group_name, get_default_buffer_config())
@@ -932,6 +935,30 @@ def calculate_dp_buffer_size() -> int:
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
def calculate_ep_buffer_size() -> int:
    """Return the HCCL buffer size (in MB) for the expert-parallel group.

    Formula:
        batch_size * hidden_size * topk * (2 * sizeof(int8) + sizeof(bf16)),
    rounded up to whole MB and floored at ``_DEFAULT_BUFFER_SIZE``. Any
    failure (config not yet available, missing attributes) falls back to the
    default size.
    """
    ep_buffer_size = _DEFAULT_BUFFER_SIZE
    try:
        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()
        hf_config = vllm_config.model_config.hf_config
        hidden_size = hf_config.hidden_size
        # Fix: HF MoE configs (Mixtral, DeepSeek, Qwen-MoE, ...) name this
        # attribute "num_experts_per_tok"; the previous lookup of
        # "num_experts_per_token" silently defaulted topk to 1 for those
        # models. Keep the old name as a fallback for compatibility.
        topk = getattr(
            hf_config, "num_experts_per_tok",
            getattr(hf_config, "num_experts_per_token", 1))
        batch_size = vllm_config.scheduler_config.max_num_batched_tokens
        int8_size = torch.iinfo(torch.int8).bits // 8
        bf16_size = torch.finfo(torch.bfloat16).bits // 8
        # Two int8 buffers plus one bf16 buffer per routed token element.
        ep_buffer_size = math.ceil(
            (batch_size * hidden_size * topk *
             (int8_size * 2 + bf16_size)) / (1024 * 1024))
    except Exception:
        # Best-effort sizing: keep the default on any error.
        pass
    return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)
# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
# significantly improve communication performance of MC2 ops dispatch/combine.

View File

@@ -2217,8 +2217,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return None
soc_version = get_ascend_device_type()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)
quant_type = getattr(
self.vllm_config.model_config.hf_config, 'moe_quantize',
getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
model_type = self.vllm_config.model_config.hf_config.model_type
if not self.parallel_config.enable_expert_parallel:
@@ -2237,7 +2238,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
elif soc_version in {AscendDeviceType._910_93}:
moe_comm_type = (MoECommType.MC2
if num_tokens <= self.mc2_tokens_capacity else
MoECommType.ALLTOALL)
MoECommType.FUSED_ALLTOALL if quant_type
== "w8a8_dynamic" else MoECommType.ALLTOALL)
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")