[feature] add_rms_norm support bias (#5790)
### What this PR does / why we need it?
This PR is to replace addRmsNorm and Add With addRmsNormBias. This way
can lead to a more effecient result.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Full Test Pass
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
This commit is contained in:
39
csrc/add_rms_norm_bias/op_host/CMakeLists.txt
Normal file
39
csrc/add_rms_norm_bias/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
# This file is a part of the CANN Open Software.
|
||||
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
# Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See LICENSE in the root of the software repository for the full text of the License.
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME AddRmsNormBias
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
add_rms_norm_bias_def.cpp
|
||||
)
|
||||
|
||||
# target_sources(opapi PRIVATE
|
||||
# aclnn_add_rms_norm.cpp
|
||||
# )
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
add_rms_norm_bias_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_add_rms_norm_bias.h")
|
||||
|
||||
install(FILES ${_GMM_Aclnn_header}
|
||||
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
|
||||
)
|
||||
71
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_def.cpp
Normal file
71
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_def.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_def.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class AddRmsNormBias : public OpDef {
|
||||
public:
|
||||
explicit AddRmsNormBias(const char* name) : OpDef(name)
|
||||
{
|
||||
this->Input("x1")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("x2")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("gamma")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Input("beta")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("y")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("rstd")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Output("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.AutoContiguous();
|
||||
this->Attr("epsilon").AttrType(OPTIONAL).Float(1e-6);
|
||||
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
this->AICore().AddConfig("ascend910_93");
|
||||
}
|
||||
};
|
||||
OP_ADD(AddRmsNormBias);
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_infershape.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "log/log.h"
|
||||
#include "util/shape_util.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
|
||||
static constexpr int IDX_0 = 0;
|
||||
static constexpr int IDX_1 = 1;
|
||||
static constexpr int IDX_2 = 2;
|
||||
|
||||
using namespace ge;
|
||||
using namespace Ops::Base;
|
||||
|
||||
namespace ops {
|
||||
|
||||
static ge::graphStatus InferShape4AddRmsNormBias(gert::InferShapeContext* context)
|
||||
{
|
||||
OP_LOGD(context, "Begin to do InferShape4AddRmsNormBias");
|
||||
|
||||
// get input shapes
|
||||
const gert::Shape* x1Shape = context->GetInputShape(IDX_0);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x1Shape);
|
||||
const gert::Shape* gammaShape = context->GetInputShape(IDX_2);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, gammaShape);
|
||||
// get output shapes
|
||||
gert::Shape* yShape = context->GetOutputShape(IDX_0);
|
||||
gert::Shape* rstdShape = context->GetOutputShape(IDX_1);
|
||||
gert::Shape* xShape = context->GetOutputShape(IDX_2);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, rstdShape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
||||
*yShape = *x1Shape;
|
||||
*xShape = *x1Shape;
|
||||
|
||||
size_t xDimNum = x1Shape->GetDimNum();
|
||||
size_t gammaDimNum = gammaShape->GetDimNum();
|
||||
|
||||
if (IsUnknownRank(*x1Shape) || IsUnknownRank(*gammaShape)) {
|
||||
SetUnknownRank(*rstdShape);
|
||||
OP_LOGD(context, "End to do InferShape4AddRmsNormBias with unknown rank.");
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
OP_CHECK_IF(
|
||||
xDimNum < gammaDimNum, OP_LOGE(context, "x dim num should not be smaller than gamma dim num."),
|
||||
return GRAPH_FAILED);
|
||||
|
||||
rstdShape->SetDimNum(xDimNum);
|
||||
for (size_t i = 0; i < xDimNum; i++) {
|
||||
if (i < xDimNum - gammaDimNum) {
|
||||
rstdShape->SetDim(i, x1Shape->GetDim(i));
|
||||
} else {
|
||||
rstdShape->SetDim(i, 1);
|
||||
}
|
||||
}
|
||||
|
||||
OP_LOGD(context, "End to do InferShape4AddRmsNormBias");
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static graphStatus InferDataType4AddRmsNormBias(gert::InferDataTypeContext* context)
|
||||
{
|
||||
OP_LOGD(context, "Begin to do InferDataType4AddRmsNormBias");
|
||||
context->SetOutputDataType(IDX_0, context->GetInputDataType(IDX_0));
|
||||
context->SetOutputDataType(IDX_1, DT_FLOAT);
|
||||
context->SetOutputDataType(IDX_2, context->GetInputDataType(IDX_0));
|
||||
OP_LOGD(context, "End to do InferDataType4AddRmsNormBias");
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(AddRmsNormBias).InferShape(InferShape4AddRmsNormBias).InferDataType(InferDataType4AddRmsNormBias);
|
||||
} // namespace ops
|
||||
443
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_tiling.cpp
Normal file
443
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_tiling.cpp
Normal file
@@ -0,0 +1,443 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_tiling.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "add_rms_norm_bias_tiling.h"
|
||||
#include "log/ops_log.h"
|
||||
|
||||
namespace optiling {
|
||||
constexpr uint32_t DTYPE_KEY_FP16 = 1;
|
||||
constexpr uint32_t DTYPE_KEY_FP32 = 2;
|
||||
constexpr uint32_t DTYPE_KEY_BF16 = 3;
|
||||
constexpr uint32_t UB_USED = 1024;
|
||||
constexpr uint32_t UB_FACTOR_B16 = 12288;
|
||||
constexpr uint32_t UB_FACTOR_B32 = 10240;
|
||||
constexpr uint32_t UB_FACTOR_B16_CUTD = 12096;
|
||||
constexpr uint32_t UB_FACTOR_B32_CUTD = 9696;
|
||||
|
||||
constexpr uint32_t UB_FACTOR_B32_WITH_BETA = 9216;
|
||||
constexpr uint32_t UB_FACTOR_B16_WITH_BETA = 11264;
|
||||
constexpr uint32_t UB_FACTOR_B32_CUTD_WITH_BETA = 8096;
|
||||
constexpr uint32_t UB_FACTOR_B16_CUTD_WITH_BETA = 10752;
|
||||
constexpr uint32_t SMALL_REDUCE_NUM_WITH_BETA = 1600;
|
||||
constexpr uint32_t FP32_WEIGHT_WITH_BETA = 28;
|
||||
constexpr uint32_t OTHER_WEIGHT_WITH_BETA = 20;
|
||||
constexpr size_t NUM_WITH_BETA = 4;
|
||||
|
||||
constexpr uint32_t BLOCK_ALIGN_NUM = 16;
|
||||
constexpr uint32_t FLOAT_BLOCK_ALIGN_NUM = 8;
|
||||
constexpr uint32_t SMALL_REDUCE_NUM = 2000;
|
||||
constexpr uint32_t MODE_NORMAL = 0;
|
||||
constexpr uint32_t MODE_SPLIT_D = 1;
|
||||
constexpr uint32_t MODE_MERGE_N = 2;
|
||||
constexpr uint32_t MODE_SINGLE_N = 3;
|
||||
constexpr uint32_t MODE_MULTI_N = 4;
|
||||
constexpr int32_t INPUT_X1_INDEX = 0;
|
||||
constexpr int32_t INPUT_X2_INDEX = 1;
|
||||
constexpr int32_t INPUT_GAMMA_INDEX = 2;
|
||||
constexpr int32_t INPUT_BETA_INDEX = 3;
|
||||
constexpr int32_t OUTPUT_Y_INDEX = 0;
|
||||
constexpr int32_t OUTPUT_RSTD_INDEX = 1;
|
||||
constexpr int32_t OUTPUT_X_INDEX = 2;
|
||||
constexpr size_t MAX_DIM_NUM = 8;
|
||||
constexpr size_t MIN_DIM_X = 1;
|
||||
constexpr size_t MIN_DIM_GAMMA = 1;
|
||||
constexpr size_t FP32_WEIGHT = 24;
|
||||
constexpr size_t OTHER_WEIGHT = 18;
|
||||
constexpr size_t DIV_FACTOR = 260;
|
||||
constexpr size_t FLOAT_PER_REPEAT = 64;
|
||||
constexpr size_t USE_SIZE = 256;
|
||||
constexpr size_t NUM = 2;
|
||||
constexpr int32_t TEN = 10;
|
||||
|
||||
constexpr int32_t PERFORMANC_DIM_ZERO = 0;
|
||||
constexpr int32_t PERFORMANC_DIM_ONE = 1;
|
||||
constexpr int32_t PERFORMANC_DIM_TWO = 2;
|
||||
constexpr int32_t PERFORMANC_DIM_THREE = 3;
|
||||
constexpr int32_t PERFORMANC_DIM_ONE_MAX = 512;
|
||||
constexpr int32_t PERFORMANC_DIM_TWO_MAX = 8;
|
||||
constexpr int32_t PERFORMANC_DIM_THREE_MAX = 5120;
|
||||
|
||||
platform_ascendc::SocVersion addRmsNormBiasSocVersion;
|
||||
|
||||
uint8_t getPerformanceFlag(uint32_t num_col, gert::Shape x_shape, gert::Shape gamma_shape, uint32_t xDtypeKey)
|
||||
{
|
||||
uint8_t isPerformance = 0;
|
||||
if(addRmsNormBiasSocVersion != platform_ascendc::SocVersion::ASCEND910B) {
|
||||
return isPerformance;
|
||||
}
|
||||
size_t xDimNum = x_shape.GetDimNum();
|
||||
size_t gammaDimNum = gamma_shape.GetDimNum();
|
||||
bool dimOK = ((xDimNum == PERFORMANC_DIM_TWO || xDimNum == PERFORMANC_DIM_THREE) && gammaDimNum == PERFORMANC_DIM_ONE);
|
||||
bool sizeOk = num_col <= PERFORMANC_DIM_THREE_MAX &&
|
||||
((xDimNum == PERFORMANC_DIM_TWO && x_shape.GetDim(PERFORMANC_DIM_ZERO) <= PERFORMANC_DIM_ONE_MAX) ||
|
||||
(xDimNum == PERFORMANC_DIM_THREE && x_shape.GetDim(PERFORMANC_DIM_ZERO) <= PERFORMANC_DIM_ONE_MAX && x_shape.GetDim(PERFORMANC_DIM_ONE) <= PERFORMANC_DIM_TWO_MAX));
|
||||
bool dtypeOk = (xDtypeKey == DTYPE_KEY_FP16 || xDtypeKey == DTYPE_KEY_BF16);
|
||||
if(dimOK && sizeOk && dtypeOk) {
|
||||
isPerformance = 1;
|
||||
}
|
||||
return isPerformance;
|
||||
}
|
||||
|
||||
static void SetByDtype(ge::DataType dataType, uint32_t& dtypeKey, uint32_t& dataPerBlock)
|
||||
{
|
||||
switch (dataType) {
|
||||
case ge::DT_FLOAT16:
|
||||
dtypeKey = DTYPE_KEY_FP16;
|
||||
dataPerBlock = BLOCK_ALIGN_NUM;
|
||||
break;
|
||||
case ge::DT_BF16:
|
||||
dtypeKey = DTYPE_KEY_BF16;
|
||||
dataPerBlock = BLOCK_ALIGN_NUM;
|
||||
break;
|
||||
default:
|
||||
dtypeKey = DTYPE_KEY_FP32;
|
||||
dataPerBlock = FLOAT_BLOCK_ALIGN_NUM;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool CheckInputOutputDim(const gert::TilingContext* context)
|
||||
{
|
||||
const gert::StorageShape* x1_shape = context->GetInputShape(INPUT_X1_INDEX);
|
||||
const gert::StorageShape* x2_shape = context->GetInputShape(INPUT_X2_INDEX);
|
||||
const gert::StorageShape* gamma_shape = context->GetInputShape(INPUT_GAMMA_INDEX);
|
||||
const gert::StorageShape* y_shape = context->GetOutputShape(OUTPUT_Y_INDEX);
|
||||
const gert::StorageShape* rstd_shape = context->GetOutputShape(OUTPUT_RSTD_INDEX);
|
||||
const gert::StorageShape* x_shape = context->GetOutputShape(OUTPUT_X_INDEX);
|
||||
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x1_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x2_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, gamma_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, y_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, rstd_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x_shape);
|
||||
|
||||
size_t x1DimNum = x1_shape->GetStorageShape().GetDimNum();
|
||||
size_t x2DimNum = x2_shape->GetStorageShape().GetDimNum();
|
||||
size_t gammaDimNum = gamma_shape->GetStorageShape().GetDimNum();
|
||||
size_t yDimNum = y_shape->GetStorageShape().GetDimNum();
|
||||
size_t rstdDimNum = rstd_shape->GetStorageShape().GetDimNum();
|
||||
size_t xDimNum = x_shape->GetStorageShape().GetDimNum();
|
||||
|
||||
OP_CHECK_IF(
|
||||
x1DimNum > MAX_DIM_NUM || x1DimNum < MIN_DIM_X,
|
||||
OP_LOGE(context, "Input x1's dim num should not greater than 8 or smaller than 1."),
|
||||
return false);
|
||||
OP_CHECK_IF(
|
||||
gammaDimNum > MAX_DIM_NUM || gammaDimNum < MIN_DIM_GAMMA,
|
||||
OP_LOGE(context, "Input gamma's dim num should not greater than 8 or smaller than 1."),
|
||||
return false);
|
||||
OP_CHECK_IF(
|
||||
x1DimNum != yDimNum, OP_LOGE(context, "Input x's dim num must equal to output y's dim num."),
|
||||
return false);
|
||||
|
||||
OP_CHECK_IF(
|
||||
x1DimNum != x2DimNum,
|
||||
OP_LOGE(context, "Input x2/x1 shape invaild, dim num is not equal x1 dim."), return false);
|
||||
OP_CHECK_IF(
|
||||
(yDimNum != xDimNum) || (xDimNum != x1DimNum) || (rstdDimNum != x1DimNum),
|
||||
OP_LOGE(context, "Output y/x/rstd shape invaild, dim num is not equal x1 dim."), return false);
|
||||
OP_CHECK_IF(
|
||||
x1DimNum < gammaDimNum, OP_LOGE(context, "X1 dim num should not be smaller than gamma dim num."),
|
||||
return false);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CheckInputOutputShape(const gert::TilingContext* context)
|
||||
{
|
||||
OP_CHECK_IF(!CheckInputOutputDim(context), OP_LOGE(context, "Input Dim invalid."), return false);
|
||||
const gert::StorageShape* x1_shape = context->GetInputShape(INPUT_X1_INDEX);
|
||||
const gert::StorageShape* x2_shape = context->GetInputShape(INPUT_X2_INDEX);
|
||||
const gert::StorageShape* gamma_shape = context->GetInputShape(INPUT_GAMMA_INDEX);
|
||||
const gert::StorageShape* y_shape = context->GetOutputShape(OUTPUT_Y_INDEX);
|
||||
const gert::StorageShape* rstd_shape = context->GetOutputShape(OUTPUT_RSTD_INDEX);
|
||||
const gert::StorageShape* x_shape = context->GetOutputShape(OUTPUT_X_INDEX);
|
||||
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x1_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x2_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, gamma_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, y_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, rstd_shape);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, x_shape);
|
||||
|
||||
size_t x1DimNum = x1_shape->GetStorageShape().GetDimNum();
|
||||
size_t gammaDimNum = gamma_shape->GetStorageShape().GetDimNum();
|
||||
|
||||
for (uint32_t i = 0; i < x1DimNum; i++) {
|
||||
OP_CHECK_IF(
|
||||
x1_shape->GetStorageShape().GetDim(i) == 0, OP_LOGE(context, "Input x1 shape can not be 0."),
|
||||
return false);
|
||||
OP_CHECK_IF(
|
||||
x2_shape->GetStorageShape().GetDim(i) != x1_shape->GetStorageShape().GetDim(i),
|
||||
OP_LOGE(context, "Input x2/x1 shape invaild, shape is not equal x1 shape."), return false);
|
||||
OP_CHECK_IF(
|
||||
(y_shape->GetStorageShape().GetDim(i) != x1_shape->GetStorageShape().GetDim(i)) ||
|
||||
(x_shape->GetStorageShape().GetDim(i) != x1_shape->GetStorageShape().GetDim(i)),
|
||||
OP_LOGE(context, "Input y/x shape invaild, shape is not equal x1 shape."), return false);
|
||||
}
|
||||
for (uint32_t i = 0; i < x1DimNum - gammaDimNum; i++) {
|
||||
OP_CHECK_IF(
|
||||
rstd_shape->GetStorageShape().GetDim(i) != x2_shape->GetStorageShape().GetDim(i),
|
||||
OP_LOGE(context, "Output rstd shape invaild, shape is not equal x1 first few dim."),
|
||||
return false);
|
||||
}
|
||||
for (uint32_t i = 0; i < gammaDimNum; i++) {
|
||||
OP_CHECK_IF(
|
||||
gamma_shape->GetStorageShape().GetDim(i) != x1_shape->GetStorageShape().GetDim(x1DimNum - gammaDimNum + i),
|
||||
OP_LOGE(context, "Input gamma shape invaild, gamma shape is not equal x1 last few dim."),
|
||||
return false);
|
||||
OP_CHECK_IF(
|
||||
rstd_shape->GetStorageShape().GetDim(x1DimNum - 1 - i) != 1,
|
||||
OP_LOGE(context, "Output rstd shape invaild, last few dim is not equal to 1."),
|
||||
return false);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void GetCompileParameters(
|
||||
gert::TilingContext* context, uint32_t& numCore, uint64_t& ubSize)
|
||||
{
|
||||
auto ptrCompileInfo = reinterpret_cast<const AddRmsNormBiasCompileInfo*>(context->GetCompileInfo());
|
||||
if (ptrCompileInfo == nullptr) {
|
||||
auto ascendc_platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
addRmsNormBiasSocVersion = ascendc_platform.GetSocVersion();
|
||||
numCore = ascendc_platform.GetCoreNumAiv();
|
||||
ascendc_platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||
} else {
|
||||
numCore = ptrCompileInfo->totalCoreNum;
|
||||
ubSize = ptrCompileInfo->totalUbSize;
|
||||
addRmsNormBiasSocVersion = ptrCompileInfo->socVersion;
|
||||
}
|
||||
ubSize -= UB_USED;
|
||||
}
|
||||
|
||||
static void CalculateRowAndColParameters(gert::TilingContext* context, uint32_t& numRow, uint32_t& numCol)
|
||||
{
|
||||
const gert::Shape x1_shape = context->GetInputShape(0)->GetStorageShape();
|
||||
const size_t gammaIndex = 2;
|
||||
const gert::Shape gamma_shape = context->GetInputShape(gammaIndex)->GetStorageShape();
|
||||
numCol = gamma_shape.GetShapeSize();
|
||||
|
||||
const size_t x1DimNum = x1_shape.GetDimNum();
|
||||
const size_t gammaDimNum = gamma_shape.GetDimNum();
|
||||
numRow = 1U;
|
||||
for (size_t i = 0; i < x1DimNum - gammaDimNum; ++i) {
|
||||
numRow *= x1_shape.GetDim(i);
|
||||
}
|
||||
}
|
||||
|
||||
static ge::graphStatus GetEpsilonParameter(gert::TilingContext* context, float& epsilon)
|
||||
{
|
||||
auto attrs = context->GetAttrs();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||
epsilon = *attrs->GetFloat(0);
|
||||
OP_CHECK_IF(
|
||||
epsilon < 0, OP_LOGE(context, "Epsilon less than zero, please check."), return ge::GRAPH_FAILED);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static void CalculateBlockParameters(
|
||||
uint32_t numRow, uint32_t numCore, uint32_t& blockFactor, uint32_t& latsBlockFactor, uint32_t& useCoreNum)
|
||||
{
|
||||
blockFactor = 1U;
|
||||
uint32_t tileNum = CeilDiv(numRow, numCore * blockFactor);
|
||||
blockFactor *= tileNum;
|
||||
useCoreNum = CeilDiv(numRow, blockFactor);
|
||||
latsBlockFactor = numRow - blockFactor * (useCoreNum - 1);
|
||||
}
|
||||
|
||||
static ge::DataType SetDataTypeParameters(gert::TilingContext* context, uint32_t& dtype_key, uint32_t& data_per_block)
|
||||
{
|
||||
auto data_type = context->GetInputDesc(0)->GetDataType();
|
||||
dtype_key = DTYPE_KEY_FP16;
|
||||
SetByDtype(data_type, dtype_key, data_per_block);
|
||||
return data_type;
|
||||
}
|
||||
|
||||
static void DetermineModeParameters(
|
||||
AddRMSNormBiasTilingData* tiling,
|
||||
uint32_t numCol, uint32_t& ubFactor, uint32_t& rowFactor, uint32_t blockFactor,
|
||||
uint32_t latsBlockFactor, ge::DataType dataType, uint32_t dtypKey, uint64_t ubSize,
|
||||
uint32_t dataPerBlock, uint32_t numColAlign, uint32_t& modeKey, uint32_t isPerformance)
|
||||
{
|
||||
if (numCol > ubFactor) {
|
||||
modeKey = MODE_SPLIT_D;
|
||||
ubFactor = tiling->get_nullptr_beta() == 1 ? ((dataType == ge::DT_FLOAT) ? UB_FACTOR_B32_CUTD : UB_FACTOR_B16_CUTD) : ((dataType == ge::DT_FLOAT) ? UB_FACTOR_B32_CUTD_WITH_BETA : UB_FACTOR_B16_CUTD_WITH_BETA);
|
||||
uint32_t colTileNum = CeilDiv(numCol, ubFactor);
|
||||
ubFactor = CeilDiv(numCol, colTileNum * dataPerBlock) * dataPerBlock;
|
||||
} else if (blockFactor == 1 && addRmsNormBiasSocVersion != platform_ascendc::SocVersion::ASCEND310P) {
|
||||
modeKey = MODE_SINGLE_N;
|
||||
} else if (((tiling->get_nullptr_beta() == 1 && numColAlign <= SMALL_REDUCE_NUM) || (tiling->get_nullptr_beta() == 0 && numColAlign <= SMALL_REDUCE_NUM_WITH_BETA)) && addRmsNormBiasSocVersion != platform_ascendc::SocVersion::ASCEND310P) {
|
||||
modeKey = MODE_MERGE_N;
|
||||
uint64_t numColAlignWeight = tiling->get_nullptr_beta() == 1 ? ((dtypKey == DTYPE_KEY_FP32) ? FP32_WEIGHT : OTHER_WEIGHT) : ((dtypKey == DTYPE_KEY_FP32) ? FP32_WEIGHT_WITH_BETA : OTHER_WEIGHT_WITH_BETA);
|
||||
rowFactor = static_cast<uint32_t>(ubSize) /
|
||||
(numColAlign * static_cast<uint32_t>(numColAlignWeight) + static_cast<uint32_t>(DIV_FACTOR));
|
||||
ubFactor = rowFactor * numColAlign;
|
||||
|
||||
uint32_t mulLoopFp32 = numColAlign / 64;
|
||||
uint32_t mulTailFp32 = numColAlign - mulLoopFp32 * 64;
|
||||
uint8_t dstRepStrideFp32 = numColAlign / 8;
|
||||
|
||||
uint32_t mulLoopFp16 = numColAlign / 128;
|
||||
uint32_t mulTailFp16 = numColAlign - mulLoopFp16 * 128;
|
||||
uint8_t dstRepStrideFp16 = numColAlign / 16;
|
||||
|
||||
tiling->set_is_performance(isPerformance);
|
||||
tiling->set_mul_loop_fp32(mulLoopFp32);
|
||||
tiling->set_mul_tail_fp32(mulTailFp32);
|
||||
tiling->set_dst_rep_stride_fp32(dstRepStrideFp32);
|
||||
tiling->set_mul_loop_fp16(mulLoopFp16);
|
||||
tiling->set_mul_tail_fp16(mulTailFp16);
|
||||
tiling->set_dst_rep_stride_fp16(dstRepStrideFp16);
|
||||
} else if ((dataType == ge::DT_FLOAT16 || isPerformance == 1) && numCol == numColAlign) {
|
||||
modeKey = MODE_MULTI_N;
|
||||
rowFactor = (static_cast<uint32_t>(ubSize) - static_cast<uint32_t>(USE_SIZE) -
|
||||
numColAlign * static_cast<uint32_t>(tiling->get_nullptr_beta() == 1 ? NUM : NUM_WITH_BETA)) /
|
||||
(numColAlign * BLOCK_ALIGN_NUM + static_cast<uint32_t>(FLOAT_PER_REPEAT));
|
||||
ubFactor = rowFactor * numColAlign;
|
||||
if (rowFactor == 0U) {
|
||||
modeKey = MODE_NORMAL;
|
||||
rowFactor = FLOAT_PER_REPEAT;
|
||||
ubFactor = UB_FACTOR_B16;
|
||||
}
|
||||
}
|
||||
uint32_t rowLoop = CeilDiv(blockFactor, rowFactor);
|
||||
uint32_t lastBlockRowLoop = CeilDiv(latsBlockFactor, rowFactor);
|
||||
uint32_t rowTail = blockFactor - (rowLoop - 1) * rowFactor;
|
||||
uint32_t lastBlockRowTail = latsBlockFactor - (lastBlockRowLoop - 1) * rowFactor;
|
||||
tiling->set_row_loop(rowLoop);
|
||||
tiling->set_last_block_row_loop(lastBlockRowLoop);
|
||||
tiling->set_row_tail(rowTail);
|
||||
tiling->set_last_block_row_tail(lastBlockRowTail);
|
||||
}
|
||||
|
||||
static void SetTilingParameters(
|
||||
AddRMSNormBiasTilingData* tiling, uint32_t num_row, uint32_t num_col, uint32_t numColAlign,
|
||||
uint32_t block_factor, uint32_t latsBlockFactor, uint32_t row_factor,
|
||||
uint32_t ub_factor, float epsilon)
|
||||
{
|
||||
const float avg_factor = (num_col == 0) ? 0 : 1.0f / num_col;
|
||||
tiling->set_num_row(num_row);
|
||||
tiling->set_num_col(num_col);
|
||||
tiling->set_num_col_align(numColAlign);
|
||||
tiling->set_block_factor(block_factor);
|
||||
tiling->set_last_block_factor(latsBlockFactor);
|
||||
tiling->set_row_factor(row_factor);
|
||||
tiling->set_ub_factor(ub_factor);
|
||||
tiling->set_epsilon(epsilon);
|
||||
tiling->set_avg_factor(avg_factor);
|
||||
}
|
||||
|
||||
static void SaveTilingData(
|
||||
gert::TilingContext* context, AddRMSNormBiasTilingData* tiling, uint32_t dtype_key, uint32_t mode_key)
|
||||
{
|
||||
const uint32_t tiling_key = dtype_key * 10 + mode_key;
|
||||
context->SetTilingKey(tiling_key);
|
||||
tiling->SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling->GetDataSize());
|
||||
}
|
||||
|
||||
static void SetWorkspaceSize(gert::TilingContext* context)
|
||||
{
|
||||
constexpr size_t sysWorkspaceSize = 16 * 1024 * 1024;
|
||||
constexpr size_t usrSize = 256;
|
||||
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
|
||||
currentWorkspace[0] = usrSize + sysWorkspaceSize;
|
||||
}
|
||||
|
||||
static void LogTilingResults(
|
||||
gert::TilingContext* context, AddRMSNormBiasTilingData* tiling, uint32_t mode_key, uint32_t dtype_key,
|
||||
uint32_t use_core_num, float epsilon)
|
||||
{
|
||||
OPS_LOG_I(context, "Tiling Key: %u", dtype_key * TEN + mode_key);
|
||||
OPS_LOG_I(context, "Block Dim: %u", use_core_num);
|
||||
OPS_LOG_I(context, "usr Workspace: 256");
|
||||
OPS_LOG_I(
|
||||
context,
|
||||
"num_row: %d, num_col: %d, block_factor: %d, row_factor: %d, ub_factor: %d, epsilon: %f, avg_factor: %f",
|
||||
tiling->get_num_row(), tiling->get_num_col(), tiling->get_block_factor(), tiling->get_row_factor(),
|
||||
tiling->get_ub_factor(), epsilon, tiling->get_avg_factor());
|
||||
}
|
||||
|
||||
static ge::graphStatus Tiling4AddRmsNormBias(gert::TilingContext* context)
|
||||
{
|
||||
OP_LOGI("Tiling4AddRmsNormBias", "Enter Tiling4AddRmsNormBias");
|
||||
OPS_LOG_D(context, "Tiling4AddRmsNormBias1 running. \n");
|
||||
OP_CHECK_IF(
|
||||
!CheckInputOutputShape(context), OP_LOGE(context, "Input shape invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
AddRMSNormBiasTilingData tiling;
|
||||
|
||||
auto betaDesc = context->GetOptionalInputDesc(INPUT_BETA_INDEX);
|
||||
tiling.set_nullptr_beta(betaDesc == nullptr ? 1 : 0);
|
||||
|
||||
uint32_t num_core;
|
||||
uint64_t ub_size;
|
||||
GetCompileParameters(context, num_core, ub_size);
|
||||
uint32_t num_row;
|
||||
uint32_t num_col;
|
||||
CalculateRowAndColParameters(context, num_row, num_col);
|
||||
float epsilon = 0;
|
||||
GetEpsilonParameter(context, epsilon);
|
||||
if (epsilon < 0) {
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
uint32_t block_factor;
|
||||
uint32_t latsBlockFactor;
|
||||
uint32_t use_core_num;
|
||||
CalculateBlockParameters(num_row, num_core, block_factor, latsBlockFactor, use_core_num);
|
||||
context->SetBlockDim(use_core_num);
|
||||
uint32_t dtype_key;
|
||||
uint32_t data_per_block;
|
||||
ge::DataType data_type = SetDataTypeParameters(context, dtype_key, data_per_block);
|
||||
uint32_t mode_key = MODE_NORMAL;
|
||||
uint32_t row_factor = 64;
|
||||
uint32_t ub_factor = betaDesc == nullptr ? ((dtype_key == DTYPE_KEY_FP32) ? UB_FACTOR_B32 : UB_FACTOR_B16) : ((dtype_key == DTYPE_KEY_FP32) ? UB_FACTOR_B32_WITH_BETA : UB_FACTOR_B16_WITH_BETA);
|
||||
uint32_t numColAlign = CeilDiv(num_col, data_per_block) * data_per_block;
|
||||
const gert::Shape x1_shape = context->GetInputShape(0)->GetStorageShape();
|
||||
const gert::Shape gamma_shape = context->GetInputShape(2)->GetStorageShape();
|
||||
uint8_t isPerformance = getPerformanceFlag(num_col, x1_shape, gamma_shape, dtype_key);
|
||||
DetermineModeParameters(
|
||||
&tiling,
|
||||
num_col, ub_factor, row_factor, block_factor, latsBlockFactor,
|
||||
data_type, dtype_key, ub_size, data_per_block,
|
||||
numColAlign, mode_key, isPerformance);
|
||||
SetTilingParameters(&tiling, num_row, num_col, numColAlign, block_factor, latsBlockFactor, row_factor, ub_factor, epsilon);
|
||||
SaveTilingData(context, &tiling, dtype_key, mode_key);
|
||||
SetWorkspaceSize(context);
|
||||
LogTilingResults(context, &tiling, mode_key, dtype_key, use_core_num, epsilon);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingPrepare4AddRmsNormBias(gert::TilingParseContext* context)
|
||||
{
|
||||
OPS_LOG_D(context, "TilingPrepare4AddRmsNormBias running. \n");
|
||||
OP_LOGI(context, "TilingPrepare4AddRmsNormBias running.");
|
||||
auto compileInfo = context->GetCompiledInfo<AddRmsNormBiasCompileInfo>();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
|
||||
auto platformInfo = context->GetPlatformInfo();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
|
||||
|
||||
compileInfo->socVersion = ascendcPlatform.GetSocVersion();
|
||||
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNumAiv();
|
||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfo->totalUbSize);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(AddRmsNormBias).Tiling(Tiling4AddRmsNormBias).TilingParse<AddRmsNormBiasCompileInfo>(TilingPrepare4AddRmsNormBias);
|
||||
|
||||
} // namespace optiling
|
||||
53
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_tiling.h
Normal file
53
csrc/add_rms_norm_bias/op_host/add_rms_norm_bias_tiling.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_RUNTIME_ADD_RMS_NORM_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_RUNTIME_ADD_RMS_NORM_H_
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "error_log.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "platform/platform_infos_def.h"
|
||||
|
||||
namespace optiling {
|
||||
BEGIN_TILING_DATA_DEF(AddRMSNormBiasTilingData)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, num_row);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, num_col);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, block_factor);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, row_factor);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, ub_factor);
|
||||
TILING_DATA_FIELD_DEF(float, epsilon);
|
||||
TILING_DATA_FIELD_DEF(float, avg_factor);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, num_col_align);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, last_block_factor);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, row_loop);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, last_block_row_loop);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, row_tail);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, last_block_row_tail);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mul_loop_fp32);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mul_tail_fp32);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dst_rep_stride_fp32);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mul_loop_fp16);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, mul_tail_fp16);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, dst_rep_stride_fp16);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, is_performance);
|
||||
TILING_DATA_FIELD_DEF(uint32_t, nullptr_beta);
|
||||
END_TILING_DATA_DEF;
|
||||
|
||||
struct AddRmsNormBiasCompileInfo {
|
||||
uint32_t totalCoreNum = 0;
|
||||
uint64_t totalUbSize = 0;
|
||||
platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910_95;
|
||||
};
|
||||
|
||||
REGISTER_TILING_DATA_CLASS(AddRmsNormBias, AddRMSNormBiasTilingData)
|
||||
} // namespace optiling
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_RUNTIME_ADD_RMS_NORM_BIAS_H_
|
||||
71
csrc/add_rms_norm_bias/op_host/error_log.h
Normal file
71
csrc/add_rms_norm_bias/op_host/error_log.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
|
||||
#include <string>
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
#define OP_LOGI(opname, ...)
|
||||
#define OP_LOGW(opname, ...) \
|
||||
do { \
|
||||
printf("[WARN][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE_WITHOUT_REPORT(opname, ...) \
|
||||
do { \
|
||||
printf("[ERRORx][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGE(opname, ...) \
|
||||
do { \
|
||||
printf("[ERROR][%s] ", (opname), ##__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
} while (0)
|
||||
|
||||
#define OP_LOGD(opname, ...)
|
||||
|
||||
namespace optiling {
|
||||
|
||||
#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \
|
||||
do { \
|
||||
OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
|
||||
#define OP_CHECK_IF(cond, log_func, expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
OP_LOGE(context->GetNodeType(), "%s is null", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
template <typename T>
|
||||
T CeilAlign(T a, T b)
|
||||
{
|
||||
return (a + b - 1) / b * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T CeilDiv(T a, T b)
|
||||
{
|
||||
if (b == 0) {
|
||||
return a;
|
||||
}
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_
|
||||
72
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.cpp
Normal file
72
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias.cpp
|
||||
* \brief
|
||||
*/
|
||||
#include "add_rms_norm_bias.h"
|
||||
#include "add_rms_norm_bias_split_d.h"
|
||||
#include "add_rms_norm_bias_merge_n.h"
|
||||
#include "add_rms_norm_bias_multi_n.h"
|
||||
#include "add_rms_norm_bias_single_n.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
#define GENERAL_OP_IMPL(templateClass, ...) \
|
||||
do { \
|
||||
templateClass<__VA_ARGS__> op(&pipe); \
|
||||
op.Init(x1, x2, gamma, beta, y, rstd, x, &tilingData); \
|
||||
op.Process(); \
|
||||
} while (0)
|
||||
|
||||
extern "C" __global__ __aicore__ void add_rms_norm_bias(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
TPipe pipe;
|
||||
GET_TILING_DATA(tilingData, tiling);
|
||||
if (TILING_KEY_IS(10)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBias, half);
|
||||
} else if (TILING_KEY_IS(20)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBias, float);
|
||||
} else if (TILING_KEY_IS(30)) {
|
||||
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBias, bfloat16_t);
|
||||
#endif
|
||||
} else if (TILING_KEY_IS(11)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, half);
|
||||
} else if (TILING_KEY_IS(21)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, float);
|
||||
} else if (TILING_KEY_IS(31)) {
|
||||
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSplitD, bfloat16_t);
|
||||
#endif
|
||||
} else if (TILING_KEY_IS(12)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, half);
|
||||
} else if (TILING_KEY_IS(22)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, float);
|
||||
} else if (TILING_KEY_IS(32)) {
|
||||
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasMergeN, bfloat16_t);
|
||||
#endif
|
||||
} else if (TILING_KEY_IS(13)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, half);
|
||||
} else if (TILING_KEY_IS(23)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, float);
|
||||
} else if (TILING_KEY_IS(33)) {
|
||||
#if !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasSingleN, bfloat16_t);
|
||||
#endif
|
||||
} else if (TILING_KEY_IS(14)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, half);
|
||||
} else if (TILING_KEY_IS(34)) {
|
||||
GENERAL_OP_IMPL(KernelAddRmsNormBiasMultiN, bfloat16_t);
|
||||
}
|
||||
}
|
||||
368
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.h
Normal file
368
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias.h
Normal file
@@ -0,0 +1,368 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias.h
|
||||
* \brief add rms norm bias file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_H_
|
||||
#define ADD_RMS_NORM_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBias {
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBias(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
this->numRow = tiling->num_row;
|
||||
this->numCol = tiling->num_col;
|
||||
this->blockFactor = tiling->block_factor;
|
||||
this->rowFactor = tiling->row_factor;
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
|
||||
blockIdx_ = GetBlockIdx();
|
||||
if (blockIdx_ < GetBlockNum() - 1) {
|
||||
this->rowWork = blockFactor;
|
||||
} else if (blockIdx_ == GetBlockNum() - 1) {
|
||||
this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
|
||||
}
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
|
||||
// pipe alloc memory to queue, the unit is Bytes
|
||||
Ppipe->InitBuffer(inQueueX, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
if (!this->nullptrBeta) {
|
||||
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
}
|
||||
Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
|
||||
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
|
||||
}
|
||||
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
|
||||
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
CopyInGammaBeta();
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
|
||||
LocalTensor<T> betaLocal;
|
||||
if (!this->nullptrBeta) {
|
||||
betaLocal = inQueueBeta.DeQue<T>();
|
||||
}
|
||||
uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
|
||||
uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;
|
||||
|
||||
for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
|
||||
SubProcess(i_o, rowFactor, gammaLocal, betaLocal);
|
||||
}
|
||||
SubProcess(i_o_max - 1, row_tail, gammaLocal, betaLocal);
|
||||
inQueueGamma.FreeTensor(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
inQueueBeta.FreeTensor(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
|
||||
{
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
uint32_t gm_bias = (i_o * rowFactor + i_i) * numCol;
|
||||
CopyIn(gm_bias);
|
||||
Compute(i_i, gammaLocal, betaLocal, rstdLocal);
|
||||
CopyOutY(gm_bias);
|
||||
}
|
||||
outQueueRstd.EnQue<float>(rstdLocal);
|
||||
CopyOutRstd(i_o, calc_row_num);
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyIn(uint32_t gm_bias)
|
||||
{
|
||||
LocalTensor<T> x1Local_in = inQueueX.AllocTensor<T>();
|
||||
LocalTensor<T> x2Local = sqxBuf.Get<T>();
|
||||
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
|
||||
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
x2Local = x2Local[ubFactor];
|
||||
}
|
||||
|
||||
DataCopyCustom<T>(x1Local_in, x1Gm[gm_bias], numCol);
|
||||
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], numCol);
|
||||
inQueueX.EnQue(x1Local_in);
|
||||
auto x1Local = inQueueX.DeQue<T>();
|
||||
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
|
||||
Add(xLocal, x1Local, x2Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else if constexpr (is_same<T, bfloat16_t>::value) {
|
||||
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> x2_fp32 = sqxBuf.Get<float>();
|
||||
Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x1_fp32, x1_fp32, x2_fp32, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(xLocal, x1Local, (float)0, numCol);
|
||||
}
|
||||
inQueueX.FreeTensor(x1Local);
|
||||
|
||||
// CopyOut x1 + x2
|
||||
outQueueY.EnQue(xLocal);
|
||||
auto x_out = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(xGm[gm_bias], x_out, numCol);
|
||||
outQueueY.FreeTensor(x_out);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInGammaBeta()
|
||||
{
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
|
||||
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
|
||||
inQueueGamma.EnQue(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
|
||||
DataCopyCustom<T>(betaLocal, betaGm, numCol);
|
||||
inQueueBeta.EnQue(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<float> gammaLocal, LocalTensor<float> betaLocal, LocalTensor<float> rstdLocal)
|
||||
{
|
||||
LocalTensor<float> xLocal = inQueueX.AllocTensor<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
Mul(sqx, xLocal, xLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sqx, sqx, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(sqx, sqx, 1);
|
||||
Duplicate(reduce_buf_local, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqx, reduce_buf_local, sqx, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = sqx.GetValue(0);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
rstdLocal.SetValue(inner_progress, rstdValue);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
|
||||
Muls(yLocal, xLocal, rstdValue, numCol);
|
||||
inQueueX.FreeTensor(xLocal);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(yLocal, gammaLocal, yLocal, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(yLocal, betaLocal, yLocal, numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
outQueueY.EnQue<float>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void Compute(
|
||||
uint32_t inner_progress, LocalTensor<bfloat16_t> gammaLocal, LocalTensor<bfloat16_t> betaLocal, LocalTensor<float> rstdLocal)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
|
||||
Mul(sqx, x_fp32, x_fp32, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Adds(sqx, sqx, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(sqx, sqx, 1);
|
||||
Duplicate(reduce_buf_local, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqx, reduce_buf_local, sqx, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = sqx.GetValue(0);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
rstdLocal.SetValue(inner_progress, rstdValue);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(x_fp32, x_fp32, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, numCol); // gamma_fp32 reuse sqx
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(x_fp32, x_fp32, sqx, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(sqx, betaLocal, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x_fp32, x_fp32, sqx, numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(event_v_mte);
|
||||
WaitFlag<HardEvent::V_MTE2>(event_v_mte);
|
||||
|
||||
outQueueY.EnQue<bfloat16_t>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void Compute(uint32_t inner_progress, LocalTensor<half> gammaLocal, LocalTensor<half> betaLocal, LocalTensor<float> rstdLocal)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
|
||||
Mul(sqx, x_fp32, x_fp32, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Adds(sqx, sqx, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(sqx, sqx, 1);
|
||||
Duplicate(reduce_buf_local, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqx, reduce_buf_local, sqx, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = sqx.GetValue(0);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
rstdLocal.SetValue(inner_progress, rstdValue);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(x_fp32, x_fp32, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, numCol);
|
||||
|
||||
event_t event_v_mte = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(event_v_mte);
|
||||
WaitFlag<HardEvent::V_MTE2>(event_v_mte);
|
||||
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(yLocal, gammaLocal, yLocal, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(yLocal, betaLocal, yLocal, numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
outQueueY.EnQue<half>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutY(uint32_t progress)
|
||||
{
|
||||
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(yGm[progress], yLocal, numCol);
|
||||
outQueueY.FreeTensor(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
|
||||
#endif
|
||||
outQueueRstd.FreeTensor(rstdLocal);
|
||||
}
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
|
||||
// create queues for output, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
|
||||
|
||||
TBuf<TPosition::VECCALC> xFp32Buf;
|
||||
TBuf<TPosition::VECCALC> sqxBuf;
|
||||
TBuf<TPosition::VECCALC> reduceFp32Buf;
|
||||
GlobalTensor<T> x1Gm;
|
||||
GlobalTensor<T> x2Gm;
|
||||
GlobalTensor<T> gammaGm;
|
||||
GlobalTensor<T> betaGm;
|
||||
GlobalTensor<T> yGm;
|
||||
GlobalTensor<float> rstdGm;
|
||||
GlobalTensor<T> xGm;
|
||||
|
||||
uint32_t numRow;
|
||||
uint32_t numCol;
|
||||
uint32_t blockFactor; // number of calculations rows on each core
|
||||
uint32_t rowFactor;
|
||||
uint32_t ubFactor;
|
||||
float epsilon;
|
||||
float avgFactor;
|
||||
int32_t blockIdx_;
|
||||
uint32_t rowWork = 1;
|
||||
uint32_t nullptrBeta = 0;
|
||||
};
|
||||
#endif // ADD_RMS_NORM_BIAS_H_
|
||||
471
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_merge_n.h
Normal file
471
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_merge_n.h
Normal file
@@ -0,0 +1,471 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_merge_n.h
|
||||
* \brief add rms norm bias merge n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasMergeN {
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBiasMergeN(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
this->numRow = tiling->num_row;
|
||||
this->numCol = tiling->num_col;
|
||||
this->numColAlign = tiling->num_col_align;
|
||||
this->blockFactor = tiling->block_factor;
|
||||
this->rowFactor = tiling->row_factor;
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = tiling->avg_factor;
|
||||
|
||||
blockIdx_ = GetBlockIdx();
|
||||
if (blockIdx_ < GetBlockNum() - 1) {
|
||||
this->rowWork = blockFactor;
|
||||
this->rowLoop = tiling->row_loop;
|
||||
this->rowTail = tiling->row_tail;
|
||||
} else if (blockIdx_ == GetBlockNum() - 1) {
|
||||
this->rowWork = tiling->last_block_factor;
|
||||
this->rowLoop = tiling->last_block_row_loop;
|
||||
this->rowTail = tiling->last_block_row_tail;
|
||||
}
|
||||
this->mulLoopFp32 = tiling->mul_loop_fp32;
|
||||
this->mulTailFp32 = tiling->mul_tail_fp32;
|
||||
this->dstRepStrideFp32 = tiling->dst_rep_stride_fp32;
|
||||
this->mulLoopFp16 = tiling->mul_loop_fp16;
|
||||
this->mulTailFp16 = tiling->mul_tail_fp16;
|
||||
this->dstRepStrideFp16 = tiling->dst_rep_stride_fp16;
|
||||
this->isPerformance = tiling->is_performance;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
|
||||
// pipe alloc memory to queue, the unit is Bytes
|
||||
Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
if (!this->nullptrBeta) {
|
||||
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
}
|
||||
Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
|
||||
#else
|
||||
Ppipe->InitBuffer(rstdBuf, rowFactor * sizeof(float));
|
||||
#endif
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
|
||||
}
|
||||
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
|
||||
Ppipe->InitBuffer(tmpBuf, rowFactor * NUM_PER_REP_FP32 * sizeof(float));
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
CopyInGammaBeta();
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
|
||||
LocalTensor<T> betaLocal;
|
||||
if (!this->nullptrBeta) {
|
||||
betaLocal = inQueueBeta.DeQue<T>();
|
||||
}
|
||||
for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
|
||||
MainCompute(i_o, rowFactor, gammaLocal, betaLocal);
|
||||
}
|
||||
MainCompute(rowLoop - 1, rowTail, gammaLocal, betaLocal);
|
||||
inQueueGamma.FreeTensor(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
inQueueBeta.FreeTensor(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void MainCompute(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
|
||||
{
|
||||
uint32_t gm_bias = i_o * rowFactor * numCol;
|
||||
uint32_t elementNum = calc_row_num * numColAlign;
|
||||
CopyInX(gm_bias, calc_row_num);
|
||||
LocalTensor<T> xLocal = ComputeX(elementNum);
|
||||
CopyOutX(gm_bias, calc_row_num);
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
|
||||
ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
|
||||
outQueueRstd.EnQue<float>(rstdLocal);
|
||||
CopyOutRstd(i_o, calc_row_num);
|
||||
#else
|
||||
LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
|
||||
ComputeRstd(xLocal, rstdLocal, calc_row_num, elementNum);
|
||||
#endif
|
||||
ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num, elementNum);
|
||||
CopyOutY(gm_bias, calc_row_num);
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
|
||||
if (isNumColAlign) {
|
||||
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
|
||||
} else {
|
||||
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num, numCol);
|
||||
}
|
||||
inQueueX.EnQue(x1Local);
|
||||
LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
|
||||
if (isNumColAlign) {
|
||||
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
|
||||
} else {
|
||||
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num, numCol);
|
||||
}
|
||||
inQueueX.EnQue(x2Local);
|
||||
}
|
||||
|
||||
__aicore__ inline LocalTensor<T> ComputeX(uint32_t elementNum)
|
||||
{
|
||||
LocalTensor<T> x1Local = inQueueX.DeQue<T>();
|
||||
LocalTensor<T> x2Local = inQueueX.DeQue<T>();
|
||||
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
|
||||
if constexpr (!is_same<T, bfloat16_t>::value) {
|
||||
Add(xLocal, x1Local, x2Local, elementNum);
|
||||
} else {
|
||||
LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
|
||||
Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, elementNum);
|
||||
Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x1Fp32, x1Fp32, x2Fp32, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, elementNum);
|
||||
}
|
||||
inQueueX.FreeTensor(x1Local);
|
||||
inQueueX.FreeTensor(x2Local);
|
||||
outQueueY.EnQue(xLocal);
|
||||
PipeBarrier<PIPE_V>();
|
||||
return xLocal;
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
|
||||
{
|
||||
// CopyOut x1 + x2
|
||||
auto xOut = outQueueY.DeQue<T>();
|
||||
if (isNumColAlign) {
|
||||
DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num * numCol);
|
||||
} else {
|
||||
DataCopyCustom<T>(xGm[gm_bias], xOut, calc_row_num, numCol);
|
||||
}
|
||||
outQueueY.FreeTensor(xOut);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInGammaBeta()
|
||||
{
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
|
||||
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
|
||||
inQueueGamma.EnQue(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
|
||||
DataCopyCustom<T>(betaLocal, betaGm, numCol);
|
||||
inQueueBeta.EnQue(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
|
||||
{
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
|
||||
if constexpr (!is_same<T, float>::value) {
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqx, x_fp32, x_fp32, elementNum);
|
||||
} else {
|
||||
Mul(sqx, xLocal, xLocal, elementNum);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, avgFactor, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
ReduceSumMultiN(rstdLocal, sqx, tmpLocal, calc_row_num, numCol, numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(rstdLocal, rstdLocal, epsilon, calc_row_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(rstdLocal, rstdLocal, calc_row_num);
|
||||
Duplicate(tmpLocal, ONE, calc_row_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Div(rstdLocal, tmpLocal, rstdLocal, calc_row_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeY(
|
||||
LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num, uint32_t elementNum)
|
||||
{
|
||||
LocalTensor<float> tmpLocal = tmpBuf.Get<float>();
|
||||
uint32_t splidRow = 240;
|
||||
uint32_t rowRepeatLoop1 = calc_row_num / splidRow;
|
||||
uint32_t rowRepeatTail1 = calc_row_num - rowRepeatLoop1 * splidRow;
|
||||
for(uint32_t r_i = 0; r_i < rowRepeatLoop1; r_i ++) {
|
||||
Brcb(tmpLocal[r_i * splidRow * MOV_8], rstdLocal[r_i * splidRow], splidRow, {1, 8});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
if(rowRepeatTail1 > 0) {
|
||||
Brcb(tmpLocal[rowRepeatLoop1 * splidRow * MOV_8], rstdLocal[rowRepeatLoop1 * splidRow], rowRepeatTail1, {1, 8});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
|
||||
if constexpr (!is_same<T, float>::value) {
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
repeatByRow<float>(x_fp32, x_fp32, tmpLocal, calc_row_num, ONE_UINT);
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, elementNum);
|
||||
} else {
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
|
||||
}
|
||||
} else {
|
||||
repeatByRow<float>(yLocal, xLocal, tmpLocal, calc_row_num, ONE_UINT);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
repeatByRow<half>(yLocal, yLocal, gammaLocal, calc_row_num, TWO_UINT);
|
||||
if (!this->nullptrBeta) {
|
||||
addRepeatByRow<half>(yLocal, yLocal, betaLocal, calc_row_num, TWO_UINT);
|
||||
}
|
||||
} else if constexpr (is_same<T, bfloat16_t>::value) {
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, elementNum);
|
||||
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
repeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
|
||||
if (!this->nullptrBeta) {
|
||||
Cast(sqx, betaLocal, RoundMode::CAST_NONE, elementNum);
|
||||
PipeBarrier<PIPE_V>();
|
||||
addRepeatByRow<float>(x_fp32, x_fp32, sqx, calc_row_num, THREE_UINT);
|
||||
}
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, elementNum);
|
||||
} else {
|
||||
repeatByRow<float>(yLocal, yLocal, gammaLocal, calc_row_num, THREE_UINT);
|
||||
if (!this->nullptrBeta) {
|
||||
addRepeatByRow<float>(yLocal, yLocal, betaLocal, calc_row_num, THREE_UINT);
|
||||
}
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
outQueueY.EnQue<T>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
|
||||
if (isNumColAlign) {
|
||||
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
|
||||
} else {
|
||||
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num, numCol);
|
||||
}
|
||||
outQueueY.FreeTensor(yLocal);
|
||||
}
|
||||
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
|
||||
DataCopyCustom<float>(rstdGm[outer_progress * rowFactor], rstdLocal, num);
|
||||
outQueueRstd.FreeTensor(rstdLocal);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename U>
|
||||
__aicore__ inline void repeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
|
||||
{
|
||||
// TWO_UINT=gammaFp16 ONE_UINT=rstd
|
||||
uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
|
||||
if (type == TWO_UINT) {
|
||||
strideParams[0] = mulLoopFp16;
|
||||
strideParams[1] = mulTailFp16;
|
||||
strideParams[2] = 128;
|
||||
strideParams[4] = dstRepStrideFp16;
|
||||
} else if (type == ONE_UINT) {
|
||||
strideParams[3] = 0;
|
||||
strideParams[5] = 1;
|
||||
}
|
||||
uint32_t singlT = 255;
|
||||
uint32_t rowRepeatLoop = calc_row_num / singlT;
|
||||
uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
|
||||
uint32_t offset2 = 0;
|
||||
for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
|
||||
offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;
|
||||
mulRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
|
||||
}
|
||||
if(rowRepeatTail > 0) {
|
||||
offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
|
||||
uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
|
||||
mulRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__aicore__ inline void mulRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
|
||||
{
|
||||
uint32_t mulLoop = strideParams[0];
|
||||
uint32_t mulTail = strideParams[1];
|
||||
uint32_t strideNum = strideParams[2];
|
||||
uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
|
||||
uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
|
||||
uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
|
||||
if(src1BlkStride == 0) {
|
||||
for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
|
||||
Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if(mulTail > 0) {
|
||||
Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local, mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
for (uint32_t m_i = 0; m_i < mulLoop; m_i++) {
|
||||
Mul(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if(mulTail > 0) {
|
||||
Mul(dstLocal[mulLoop * strideNum], src1Local[mulLoop * strideNum], src2Local[mulLoop * strideNum], mulTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__aicore__ inline void addRepeatByRow(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calc_row_num, uint32_t type)
|
||||
{
|
||||
// TWO_UINT=gammaFp16 ONE_UINT=rstd
|
||||
uint32_t strideParams[6] = {mulLoopFp32, mulTailFp32, 64, 1, dstRepStrideFp32, 0};
|
||||
if (type == TWO_UINT) {
|
||||
strideParams[0] = mulLoopFp16;
|
||||
strideParams[1] = mulTailFp16;
|
||||
strideParams[2] = 128;
|
||||
strideParams[4] = dstRepStrideFp16;
|
||||
} else if (type == ONE_UINT) {
|
||||
strideParams[3] = 0;
|
||||
strideParams[5] = 1;
|
||||
}
|
||||
uint32_t singlT = 255;
|
||||
uint32_t rowRepeatLoop = calc_row_num / singlT;
|
||||
uint32_t rowRepeatTail = calc_row_num - rowRepeatLoop * singlT;
|
||||
uint32_t offset2 = 0;
|
||||
for(uint32_t r_i = 0; r_i < rowRepeatLoop; r_i ++) {
|
||||
offset2 = type == 1 ? (r_i * singlT * MOV_8) : 0;
|
||||
addRepeat<U>(dstLocal[r_i * singlT * numColAlign], src1Local[r_i * singlT * numColAlign], src2Local[offset2], singlT, strideParams);
|
||||
}
|
||||
if(rowRepeatTail > 0) {
|
||||
offset2 = type == 1 ? (rowRepeatLoop * singlT * MOV_8) : 0;
|
||||
uint32_t offset1 = rowRepeatLoop * singlT * numColAlign;
|
||||
addRepeat<U>(dstLocal[offset1], src1Local[offset1], src2Local[offset2], rowRepeatTail, strideParams);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__aicore__ inline void addRepeat(const LocalTensor<U>& dstLocal, const LocalTensor<U>& src1Local, const LocalTensor<U>& src2Local, uint32_t calcRowNum, uint32_t strideParams[6])
|
||||
{
|
||||
uint32_t addLoop = strideParams[0];
|
||||
uint32_t addTail = strideParams[1];
|
||||
uint32_t strideNum = strideParams[2];
|
||||
uint8_t src1BlkStride = static_cast<uint8_t>(strideParams[3]);
|
||||
uint8_t dstRepStride = static_cast<uint8_t>(strideParams[4]);
|
||||
uint8_t src1RepStride = static_cast<uint8_t>(strideParams[5]);
|
||||
if(src1BlkStride == 0) {
|
||||
for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
|
||||
Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local, strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if(addTail > 0) {
|
||||
Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local, addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
} else {
|
||||
for (uint32_t m_i = 0; m_i < addLoop; m_i++) {
|
||||
Add(dstLocal[m_i * strideNum], src1Local[m_i * strideNum], src2Local[m_i * strideNum], strideNum, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if(addTail > 0) {
|
||||
Add(dstLocal[addLoop * strideNum], src1Local[addLoop * strideNum], src2Local[addLoop * strideNum], addTail, calcRowNum, {1, 1, src1BlkStride, dstRepStride, dstRepStride, src1RepStride});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
|
||||
TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
|
||||
// create queues for output, in this case depth is equal to buffer num
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
|
||||
#else
|
||||
TBuf<TPosition::VECCALC> rstdBuf;
|
||||
#endif
|
||||
TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;
|
||||
|
||||
TBuf<TPosition::VECCALC> xFp32Buf;
|
||||
TBuf<TPosition::VECCALC> sqxBuf;
|
||||
TBuf<TPosition::VECCALC> tmpBuf;
|
||||
GlobalTensor<T> x1Gm;
|
||||
GlobalTensor<T> x2Gm;
|
||||
GlobalTensor<T> gammaGm;
|
||||
GlobalTensor<T> betaGm;
|
||||
GlobalTensor<T> yGm;
|
||||
GlobalTensor<float> rstdGm;
|
||||
GlobalTensor<T> xGm;
|
||||
|
||||
uint32_t numRow;
|
||||
uint32_t numCol;
|
||||
uint32_t numColAlign;
|
||||
uint32_t blockFactor; // number of calculations rows on each core
|
||||
uint32_t rowFactor;
|
||||
uint32_t ubFactor;
|
||||
float epsilon;
|
||||
float avgFactor;
|
||||
int32_t blockIdx_;
|
||||
uint32_t rowWork = 1;
|
||||
#if (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
bool isNumColAlign = true;
|
||||
#else
|
||||
bool isNumColAlign = false;
|
||||
#endif
|
||||
uint8_t isPerformance = 0;
|
||||
uint32_t rowLoop = 1;
|
||||
uint32_t rowTail = 0;
|
||||
uint32_t mulLoopFp32;
|
||||
uint32_t mulTailFp32;
|
||||
uint8_t dstRepStrideFp32;
|
||||
uint32_t mulLoopFp16;
|
||||
uint32_t mulTailFp16;
|
||||
uint8_t dstRepStrideFp16;
|
||||
uint32_t nullptrBeta = 0;
|
||||
};
|
||||
#endif // _ADD_RMS_NORM_BIAS_MERGE_N_H_
|
||||
339
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_multi_n.h
Normal file
339
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_multi_n.h
Normal file
@@ -0,0 +1,339 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_multi_n.h
|
||||
* \brief add rms norm bias multi n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_MULTI_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_MULTI_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasMultiN {
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBiasMultiN(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
this->numRow = tiling->num_row;
|
||||
this->numCol = tiling->num_col;
|
||||
this->numColAlign = tiling->num_col_align;
|
||||
this->blockFactor = tiling->block_factor;
|
||||
this->rowFactor = tiling->row_factor;
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = tiling->avg_factor;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
|
||||
blockIdx_ = GetBlockIdx();
|
||||
if (blockIdx_ < GetBlockNum() - 1) {
|
||||
this->rowWork = blockFactor;
|
||||
this->rowLoop = tiling->row_loop;
|
||||
this->rowTail = tiling->row_tail;
|
||||
} else if (blockIdx_ == GetBlockNum() - 1) {
|
||||
this->rowWork = tiling->last_block_factor;
|
||||
this->rowLoop = tiling->last_block_row_loop;
|
||||
this->rowTail = tiling->last_block_row_tail;
|
||||
}
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
|
||||
// pipe alloc memory to queue, the unit is Bytes
|
||||
Ppipe->InitBuffer(inQueueX, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, numColAlign * sizeof(T));
|
||||
if (!this->nullptrBeta) {
|
||||
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, numColAlign * sizeof(T));
|
||||
}
|
||||
Ppipe->InitBuffer(outQueueY, DOUBLE_BUFFER_NUM, ubFactor * sizeof(T));
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
|
||||
#else
|
||||
Ppipe->InitBuffer(rstdBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
|
||||
#endif
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
|
||||
}
|
||||
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
|
||||
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
|
||||
Ppipe->InitBuffer(offsetBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(uint32_t));
|
||||
}
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
CopyInGammaBeta();
|
||||
LocalTensor<T> betaLocal;
|
||||
if (!this->nullptrBeta) {
|
||||
betaLocal = inQueueBeta.DeQue<T>();
|
||||
}
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
|
||||
LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
|
||||
for (uint32_t i = 0; i < rowFactor; i++) {
|
||||
Duplicate(offsetLocal[i * NUM_PER_BLK_FP32], i * ONE_BLK_SIZE, NUM_PER_BLK_FP32);
|
||||
}
|
||||
for (uint32_t i_o = 0; i_o < rowLoop - 1; i_o++) {
|
||||
SubProcessHalf(i_o, rowFactor, gammaLocal, betaLocal);
|
||||
}
|
||||
SubProcessHalf(rowLoop - 1, rowTail, gammaLocal, betaLocal);
|
||||
inQueueGamma.FreeTensor(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
inQueueBeta.FreeTensor(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void SubProcessHalf(uint32_t i_o, uint32_t calc_row_num, LocalTensor<T>& gammaLocal, LocalTensor<T>& betaLocal)
|
||||
{
|
||||
uint32_t gm_bias = i_o * rowFactor * numCol;
|
||||
CopyInX(gm_bias, calc_row_num);
|
||||
LocalTensor<T> xLocal = ComputeX(calc_row_num);
|
||||
CopyOutX(gm_bias, calc_row_num);
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
|
||||
ComputeRstd(xLocal, rstdLocal, calc_row_num);
|
||||
outQueueRstd.EnQue<float>(rstdLocal);
|
||||
CopyOutRstd(i_o * rowFactor, calc_row_num);
|
||||
#else
|
||||
LocalTensor<float> rstdLocal = rstdBuf.Get<float>();
|
||||
ComputeRstd(xLocal, rstdLocal, calc_row_num);
|
||||
#endif
|
||||
ComputeY(xLocal, gammaLocal, betaLocal, rstdLocal, calc_row_num);
|
||||
CopyOutY(gm_bias, calc_row_num);
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyInX(uint32_t gm_bias, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<T> x1Local = inQueueX.AllocTensor<T>();
|
||||
DataCopyCustom<T>(x1Local, x1Gm[gm_bias], calc_row_num * numCol);
|
||||
inQueueX.EnQue(x1Local);
|
||||
LocalTensor<T> x2Local = inQueueX.AllocTensor<T>();
|
||||
DataCopyCustom<T>(x2Local, x2Gm[gm_bias], calc_row_num * numCol);
|
||||
inQueueX.EnQue(x2Local);
|
||||
}
|
||||
|
||||
__aicore__ inline LocalTensor<T> ComputeX(uint32_t calc_row_num)
|
||||
{
|
||||
uint32_t calc_num = calc_row_num * numColAlign;
|
||||
LocalTensor<T> x1Local = inQueueX.DeQue<T>();
|
||||
LocalTensor<T> x2Local = inQueueX.DeQue<T>();
|
||||
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
|
||||
if constexpr (!is_same<T, bfloat16_t>::value) {
|
||||
Add(xLocal, x1Local, x2Local, calc_num);
|
||||
} else {
|
||||
LocalTensor<float> x1Fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> x2Fp32 = sqxBuf.Get<float>();
|
||||
Cast(x1Fp32, x1Local, RoundMode::CAST_NONE, calc_num);
|
||||
Cast(x2Fp32, x2Local, RoundMode::CAST_NONE, calc_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x1Fp32, x1Fp32, x2Fp32, calc_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xLocal, x1Fp32, RoundMode::CAST_RINT, calc_num);
|
||||
}
|
||||
inQueueX.FreeTensor(x1Local);
|
||||
inQueueX.FreeTensor(x2Local);
|
||||
outQueueY.EnQue(xLocal);
|
||||
PipeBarrier<PIPE_V>();
|
||||
return xLocal;
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutX(uint32_t gm_bias, uint32_t calc_row_num)
|
||||
{
|
||||
// CopyOut x1 + x2
|
||||
auto x_out = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(xGm[gm_bias], x_out, calc_row_num * numCol);
|
||||
outQueueY.FreeTensor(x_out);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInGammaBeta()
|
||||
{
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
|
||||
DataCopyCustom<T>(gammaLocal, gammaGm, numCol);
|
||||
inQueueGamma.EnQue(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
|
||||
DataCopyCustom<T>(betaLocal, betaGm, numCol);
|
||||
inQueueBeta.EnQue(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeRstd(LocalTensor<T> xLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Mul(sqx, x_fp32, x_fp32, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Muls(sqx, sqx, avgFactor, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
ReduceSumCustom(rstdLocal[i_i * NUM_PER_BLK_FP32], sqx[i_i * numColAlign], reduce_buf_local, numCol);
|
||||
}
|
||||
Adds(rstdLocal, rstdLocal, epsilon, calc_row_num * NUM_PER_BLK_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Sqrt(rstdLocal, rstdLocal, calc_row_num * NUM_PER_BLK_FP32);
|
||||
Duplicate(reduce_buf_local, ONE, NUM_PER_BLK_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
int32_t repeatTimes = calc_row_num * NUM_PER_BLK_FP32 / NUM_PER_REP_FP32;
|
||||
int32_t tailCount = calc_row_num * NUM_PER_BLK_FP32 % NUM_PER_REP_FP32;
|
||||
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
|
||||
|
||||
if (likely(repeatTimes > 0)) {
|
||||
Div(rstdLocal, reduce_buf_local, rstdLocal, NUM_PER_REP_FP32, repeatTimes, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
|
||||
}
|
||||
if (unlikely(tailCount != 0)) {
|
||||
Div(rstdLocal[bodyCount], reduce_buf_local, rstdLocal[bodyCount], tailCount, 1, {1, 0, 1, DEFAULT_REPEAT_STRIDE, 0, DEFAULT_REPEAT_STRIDE});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeY(
|
||||
LocalTensor<T> xLocal, LocalTensor<T> gammaLocal, LocalTensor<T> betaLocal, LocalTensor<float> rstdLocal, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<uint32_t> offsetLocal = offsetBuf.Get<uint32_t>();
|
||||
Gather(rstdLocal, rstdLocal, offsetLocal, ZERO_UINT, calc_row_num * NUM_PER_BLK_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
int32_t repeatTimes = numCol / NUM_PER_REP_FP32;
|
||||
int32_t tailCount = numCol % NUM_PER_REP_FP32;
|
||||
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
if (likely(repeatTimes > 0)) {
|
||||
Mul(x_fp32[i_i * numColAlign], x_fp32[i_i * numColAlign], rstdLocal[i_i * NUM_PER_BLK_FP32],
|
||||
NUM_PER_REP_FP32, repeatTimes, {1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
|
||||
}
|
||||
if (unlikely(tailCount != 0)) {
|
||||
Mul(x_fp32[i_i * numColAlign + bodyCount], x_fp32[i_i * numColAlign + bodyCount],
|
||||
rstdLocal[i_i * NUM_PER_BLK_FP32], tailCount, 1,
|
||||
{1, 1, 0, DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE, 0});
|
||||
}
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<T> yLocal = outQueueY.AllocTensor<T>();
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
Mul(yLocal[i_i * numColAlign], gammaLocal, yLocal[i_i * numColAlign], numCol);
|
||||
}
|
||||
if (!this->nullptrBeta) {
|
||||
PipeBarrier<PIPE_V>();
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
Add(yLocal[i_i * numColAlign], betaLocal, yLocal[i_i * numColAlign], numCol);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<float> yfp32 = xFp32Buf.Get<float>();
|
||||
Cast(yfp32, yLocal, RoundMode::CAST_NONE, calc_row_num * numColAlign);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<float> gammaFp32 = sqxBuf.Get<float>();
|
||||
Cast(gammaFp32, gammaLocal, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
Mul(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (!this->nullptrBeta) {
|
||||
Cast(gammaFp32, betaLocal, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
Add(yfp32[i_i * numColAlign], gammaFp32, yfp32[i_i * numColAlign], numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Cast(yLocal, yfp32, RoundMode::CAST_RINT, calc_row_num * numColAlign);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
outQueueY.EnQue<T>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutY(uint32_t progress, uint32_t calc_row_num)
|
||||
{
|
||||
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(yGm[progress], yLocal, calc_row_num * numCol);
|
||||
outQueueY.FreeTensor(yLocal);
|
||||
}
|
||||
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
__aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
|
||||
DataCopyParams copyParams;
|
||||
copyParams.blockLen = sizeof(float);
|
||||
copyParams.blockCount = num;
|
||||
DataCopyPad(rstdGm[outer_progress], rstdLocal, copyParams);
|
||||
outQueueRstd.FreeTensor(rstdLocal);
|
||||
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
|
||||
TQue<QuePosition::VECIN, DOUBLE_BUFFER_NUM> inQueueX;
|
||||
// create queues for output, in this case depth is equal to buffer num
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
|
||||
#else
|
||||
TBuf<TPosition::VECCALC> rstdBuf;
|
||||
#endif
|
||||
TQue<QuePosition::VECOUT, DOUBLE_BUFFER_NUM> outQueueY;
|
||||
|
||||
TBuf<TPosition::VECCALC> xFp32Buf;
|
||||
TBuf<TPosition::VECCALC> sqxBuf;
|
||||
TBuf<TPosition::VECCALC> reduceFp32Buf;
|
||||
TBuf<TPosition::VECCALC> offsetBuf;
|
||||
GlobalTensor<T> x1Gm;
|
||||
GlobalTensor<T> x2Gm;
|
||||
GlobalTensor<T> gammaGm;
|
||||
GlobalTensor<T> betaGm;
|
||||
GlobalTensor<T> yGm;
|
||||
GlobalTensor<float> rstdGm;
|
||||
GlobalTensor<T> xGm;
|
||||
|
||||
uint32_t numRow;
|
||||
uint32_t numCol;
|
||||
uint32_t blockFactor; // number of calculations rows on each core
|
||||
uint32_t rowFactor;
|
||||
uint32_t ubFactor;
|
||||
float epsilon;
|
||||
float avgFactor;
|
||||
uint32_t numColAlign;
|
||||
int32_t blockIdx_;
|
||||
uint32_t rowWork = 1;
|
||||
uint32_t rowLoop = 1;
|
||||
uint32_t rowTail = 0;
|
||||
uint32_t nullptrBeta = 0;
|
||||
};
|
||||
#endif // ADD_RMS_NORM__BIAS_MULTI_N_H_
|
||||
376
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_single_n.h
Normal file
376
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_single_n.h
Normal file
@@ -0,0 +1,376 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_single_n.h
|
||||
* \brief add rms norm bias single n file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
#define ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasSingleN {
|
||||
static constexpr int32_t MAXBUFFER = 195584;
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBiasSingleN(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
|
||||
this->numCol = tiling->num_col;
|
||||
this->blockFactor = 1; // in this case, blockFactor = 1
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
|
||||
this->rowWork = 1;
|
||||
blockIdx_ = GetBlockIdx();
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * numCol, numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * numCol, numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * numCol, numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_, 1);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * numCol, numCol);
|
||||
|
||||
Ppipe->InitBuffer(unitBuf, MAXBUFFER); // (192 - 1) * 1024 byte
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
ProcessFp16();
|
||||
} else if constexpr (is_same<T, float>::value) {
|
||||
ProcessFp32();
|
||||
} else {
|
||||
ProcessBf16();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void ProcessFp16()
|
||||
{
|
||||
LocalTensor<float> ubLocal = unitBuf.Get<float>();
|
||||
LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
|
||||
LocalTensor<T> x1Local = xLocal[0];
|
||||
LocalTensor<T> x2Local = xLocal[ubFactor];
|
||||
LocalTensor<float> xFp32Local = ubLocal[ubFactor];
|
||||
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
|
||||
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
|
||||
|
||||
DataCopyCustom<T>(x1Local, x1Gm, numCol);
|
||||
event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
DataCopyCustom<T>(x2Local, x2Gm, numCol);
|
||||
event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copy gamma
|
||||
event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
|
||||
DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
|
||||
// copy x out
|
||||
event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(xGm, x1Local, numCol);
|
||||
event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
|
||||
SetFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
|
||||
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sqxLocal, sqxLocal, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Sqrt(sqxLocal, sqxLocal, 1);
|
||||
Duplicate(tmpLocal, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqxLocal, tmpLocal, sqxLocal, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copyout rstd
|
||||
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
|
||||
#endif
|
||||
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventVS);
|
||||
WaitFlag<HardEvent::V_S>(eventVS);
|
||||
float rstdValue = sqxLocal.GetValue(0);
|
||||
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventSV);
|
||||
WaitFlag<HardEvent::S_V>(eventSV);
|
||||
|
||||
Muls(xFp32Local, xFp32Local, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
Cast(x1Local, xFp32Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Mul(x1Local, x1Local, x2Local, numCol);
|
||||
|
||||
if (!this->nullptrBeta) {
|
||||
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
DataCopyCustom<T>(x2Local, betaGm, numCol);
|
||||
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
}
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(yGm, x1Local, numCol);
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessFp32()
|
||||
{
|
||||
LocalTensor<float> ubLocal = unitBuf.Get<float>();
|
||||
LocalTensor<T> x1Local = ubLocal[0];
|
||||
LocalTensor<T> x2Local = ubLocal[ubFactor];
|
||||
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
|
||||
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
|
||||
|
||||
DataCopyCustom<T>(x1Local, x1Gm, numCol);
|
||||
event_t eventMTE2V1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
DataCopyCustom<T>(x2Local, x2Gm, numCol);
|
||||
event_t eventMTE2V2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copy gamma
|
||||
event_t eventVMTE2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2);
|
||||
|
||||
DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
|
||||
// copy x out
|
||||
event_t eventVMTE3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(xGm, x1Local, numCol);
|
||||
event_t eventMTE3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
|
||||
SetFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
|
||||
Mul(sqxLocal, x1Local, x1Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sqxLocal, sqxLocal, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Sqrt(sqxLocal, sqxLocal, 1);
|
||||
Duplicate(tmpLocal, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqxLocal, tmpLocal, sqxLocal, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
// copyout rstd
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
|
||||
#endif
|
||||
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventVS);
|
||||
WaitFlag<HardEvent::V_S>(eventVS);
|
||||
float rstdValue = sqxLocal.GetValue(0);
|
||||
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventSV);
|
||||
WaitFlag<HardEvent::S_V>(eventSV);
|
||||
WaitFlag<HardEvent::MTE3_V>(eventMTE3V);
|
||||
Muls(x1Local, x1Local, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2);
|
||||
Mul(x1Local, x1Local, x2Local, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
DataCopyCustom<T>(x2Local, betaGm, numCol);
|
||||
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
Add(x1Local, x1Local, x2Local, numCol);
|
||||
}
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3);
|
||||
DataCopyCustom<T>(yGm, x1Local, numCol);
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessBf16()
|
||||
{
|
||||
LocalTensor<float> ubLocal = unitBuf.Get<float>();
|
||||
LocalTensor<T> xLocal = ubLocal.template ReinterpretCast<T>();
|
||||
LocalTensor<T> x1Local = xLocal[0];
|
||||
LocalTensor<T> x2Local = xLocal[ubFactor];
|
||||
LocalTensor<float> xFp32Local = ubLocal[ubFactor];
|
||||
LocalTensor<float> sqxLocal = ubLocal[ubFactor * 2];
|
||||
LocalTensor<float> tmpLocal = ubLocal[ubFactor * 3];
|
||||
|
||||
DataCopyCustom<T>(x1Local, x1Gm, numCol);
|
||||
event_t eventMTE2V1_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
|
||||
DataCopyCustom<T>(x2Local, x2Gm, numCol);
|
||||
event_t eventMTE2V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V1_BF16_0);
|
||||
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventMTE2V2_BF16_0);
|
||||
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(xFp32Local, xFp32Local, sqxLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
// copy gamma
|
||||
event_t eventVMTE2_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2_BF16_0);
|
||||
|
||||
DataCopyCustom<T>(x2Local, gammaGm, numCol); // gammaLocal use x2Local
|
||||
event_t eventMTE2V2_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1);
|
||||
|
||||
// copy x out
|
||||
event_t eventVMTE3_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_0);
|
||||
DataCopyCustom<T>(xGm, x1Local, numCol);
|
||||
event_t eventMTE3V_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
|
||||
SetFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
|
||||
|
||||
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqxLocal, xFp32Local, xFp32Local, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(sqxLocal, sqxLocal, avgFactor, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumCustom(sqxLocal, sqxLocal, tmpLocal, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(sqxLocal, sqxLocal, epsilon, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Sqrt(sqxLocal, sqxLocal, 1);
|
||||
Duplicate(tmpLocal, ONE, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(sqxLocal, tmpLocal, sqxLocal, 1);
|
||||
PipeBarrier<PIPE_V>();
|
||||
event_t eventVS_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(eventVS_BF16_0);
|
||||
WaitFlag<HardEvent::V_S>(eventVS_BF16_0);
|
||||
float rstdValue = sqxLocal.GetValue(0);
|
||||
event_t eventSV_BF16_0 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(eventSV_BF16_0);
|
||||
WaitFlag<HardEvent::S_V>(eventSV_BF16_0);
|
||||
// copyout rstd
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
event_t eventVMTE3_BF16_1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_1);
|
||||
DataCopyCustom<float>(rstdGm, sqxLocal, 1);
|
||||
event_t eventMTE3V2_BF16_0 = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>());
|
||||
SetFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
|
||||
#endif
|
||||
|
||||
Muls(xFp32Local, xFp32Local, rstdValue, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V_BF16_0);
|
||||
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xFp32Local, x1Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2V2_BF16_1);
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
WaitFlag<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
|
||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(eventMTE3V2_BF16_0);
|
||||
#endif
|
||||
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(xFp32Local, xFp32Local, sqxLocal, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
event_t eventVMTE2Beta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
|
||||
SetFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
WaitFlag<HardEvent::V_MTE2>(eventVMTE2Beta);
|
||||
DataCopyCustom<T>(x2Local, betaGm, numCol);
|
||||
event_t eventMTE2XBeta = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
|
||||
SetFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
WaitFlag<HardEvent::MTE2_V>(eventMTE2XBeta);
|
||||
Cast(sqxLocal, x2Local, RoundMode::CAST_NONE, numCol);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(xFp32Local, xFp32Local, sqxLocal, numCol);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x1Local, xFp32Local, RoundMode::CAST_RINT, numCol);
|
||||
event_t eventVMTE3_BF16_2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
|
||||
SetFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
|
||||
WaitFlag<HardEvent::V_MTE3>(eventVMTE3_BF16_2);
|
||||
DataCopyCustom<T>(yGm, x1Local, numCol);
|
||||
}
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;
|
||||
|
||||
TBuf<TPosition::VECCALC> unitBuf;
|
||||
GlobalTensor<T> x1Gm;
|
||||
GlobalTensor<T> x2Gm;
|
||||
GlobalTensor<T> gammaGm;
|
||||
GlobalTensor<T> betaGm;
|
||||
GlobalTensor<T> yGm;
|
||||
GlobalTensor<float> rstdGm;
|
||||
GlobalTensor<T> xGm;
|
||||
|
||||
uint32_t numRow;
|
||||
uint32_t numCol;
|
||||
uint32_t blockFactor; // number of calculations rows on each core
|
||||
uint32_t ubFactor;
|
||||
float epsilon;
|
||||
float avgFactor;
|
||||
int32_t blockIdx_;
|
||||
uint32_t rowWork = 1;
|
||||
uint32_t nullptrBeta = 0;
|
||||
};
|
||||
#endif // _ADD_RMS_NORM_BIAS_SINGLE_N_H_
|
||||
395
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_split_d.h
Normal file
395
csrc/add_rms_norm_bias/op_kernel/add_rms_norm_bias_split_d.h
Normal file
@@ -0,0 +1,395 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file add_rms_norm_bias_split_d.h
|
||||
* \brief add rms norm bias split d file
|
||||
*/
|
||||
#ifndef ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
#define ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
#include "./rms_norm_base.h"
|
||||
|
||||
using namespace AscendC;
|
||||
using namespace RmsNorm;
|
||||
|
||||
template <typename T>
|
||||
class KernelAddRmsNormBiasSplitD {
|
||||
public:
|
||||
__aicore__ inline KernelAddRmsNormBiasSplitD(TPipe* pipe)
|
||||
{
|
||||
Ppipe = pipe;
|
||||
}
|
||||
__aicore__ inline void Init(
|
||||
GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR beta, GM_ADDR y, GM_ADDR rstd, GM_ADDR x, const AddRMSNormBiasTilingData* tiling)
|
||||
{
|
||||
ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!");
|
||||
this->numRow = tiling->num_row;
|
||||
this->numCol = tiling->num_col;
|
||||
this->blockFactor = tiling->block_factor;
|
||||
this->rowFactor = tiling->row_factor;
|
||||
this->ubFactor = tiling->ub_factor;
|
||||
this->epsilon = tiling->epsilon;
|
||||
this->avgFactor = (numCol != 0) ? (float)1.0 / numCol : 0;
|
||||
this->nullptrBeta = tiling->nullptr_beta;
|
||||
|
||||
blockIdx_ = GetBlockIdx();
|
||||
if (blockIdx_ < GetBlockNum() - 1) {
|
||||
this->rowWork = blockFactor;
|
||||
} else if (blockIdx_ == GetBlockNum() - 1) {
|
||||
this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor;
|
||||
} else {
|
||||
}
|
||||
// get start index for current core, core parallel
|
||||
x1Gm.SetGlobalBuffer((__gm__ T*)x1 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
x2Gm.SetGlobalBuffer((__gm__ T*)x2 + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
gammaGm.SetGlobalBuffer((__gm__ T*)gamma, numCol);
|
||||
if (!this->nullptrBeta) {
|
||||
betaGm.SetGlobalBuffer((__gm__ T*)beta, numCol);
|
||||
}
|
||||
yGm.SetGlobalBuffer((__gm__ T*)y + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
rstdGm.SetGlobalBuffer((__gm__ float*)rstd + blockIdx_ * blockFactor, blockFactor);
|
||||
xGm.SetGlobalBuffer((__gm__ T*)x + blockIdx_ * blockFactor * numCol, rowWork * numCol);
|
||||
|
||||
// pipe alloc memory to queue, the unit is Bytes.
|
||||
// We need 2 buffers here for both x1 and x2.
|
||||
Ppipe->InitBuffer(inQueueX, BUFFER_NUM, 2 * ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
if (!this->nullptrBeta) {
|
||||
Ppipe->InitBuffer(inQueueBeta, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
}
|
||||
Ppipe->InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T));
|
||||
Ppipe->InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float));
|
||||
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
Ppipe->InitBuffer(xFp32Buf, ubFactor * sizeof(float));
|
||||
}
|
||||
Ppipe->InitBuffer(sqxBuf, ubFactor * sizeof(float));
|
||||
Ppipe->InitBuffer(sumBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float));
|
||||
Ppipe->InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float));
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
uint32_t i_o_max = RmsNorm::CeilDiv(rowWork, rowFactor);
|
||||
uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor;
|
||||
uint32_t j_max = RmsNorm::CeilDiv(numCol, ubFactor);
|
||||
uint32_t col_tail = numCol - (j_max - 1) * ubFactor;
|
||||
for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) {
|
||||
SubProcess(i_o, rowFactor, j_max, col_tail);
|
||||
}
|
||||
SubProcess(i_o_max - 1, row_tail, j_max, col_tail);
|
||||
}
|
||||
|
||||
__aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, uint32_t j_max, uint32_t col_tail)
|
||||
{
|
||||
LocalTensor<float> sumLocal = sumBuf.Get<float>();
|
||||
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.AllocTensor<float>();
|
||||
Duplicate(rstdLocal, (float)0.0, calc_row_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
for (uint32_t j = 0; j < j_max - 1; j++) {
|
||||
ComputeFormer(i_o, calc_row_num, j, rstdLocal, sumLocal, ubFactor);
|
||||
}
|
||||
// do tail
|
||||
ComputeFormer(i_o, calc_row_num, j_max - 1, rstdLocal, sumLocal, col_tail);
|
||||
ComputeRstd(rstdLocal, calc_row_num);
|
||||
|
||||
for (uint32_t j = 0; j < j_max - 1; j++) {
|
||||
ComputeLatter(i_o, calc_row_num, j, rstdLocal, ubFactor);
|
||||
}
|
||||
ComputeLatter(i_o, calc_row_num, j_max - 1, rstdLocal, col_tail);
|
||||
outQueueRstd.EnQue<float>(rstdLocal);
|
||||
CopyOutRstd(i_o, calc_row_num);
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyInAndAdd(uint32_t i_idx, uint32_t j_idx, uint32_t num)
|
||||
{
|
||||
LocalTensor<T> x1x2_in = inQueueX.AllocTensor<T>();
|
||||
LocalTensor<T> x1_in = x1x2_in[0];
|
||||
LocalTensor<T> x2_in = x1x2_in[ubFactor];
|
||||
DataCopyCustom<T>(x1_in, x1Gm[i_idx * numCol + j_idx * ubFactor], num);
|
||||
DataCopyCustom<T>(x2_in, x2Gm[i_idx * numCol + j_idx * ubFactor], num);
|
||||
inQueueX.EnQue(x1x2_in);
|
||||
LocalTensor<T> x1x2Local = inQueueX.DeQue<T>();
|
||||
|
||||
auto x1Local = x1x2Local[0];
|
||||
auto x2Local = x1x2Local[ubFactor];
|
||||
|
||||
LocalTensor<T> xLocal = outQueueY.AllocTensor<T>();
|
||||
|
||||
if constexpr (is_same<T, half>::value) {
|
||||
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
|
||||
|
||||
Add(xLocal, x1Local, x2Local, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
// x1+x2 saved in x1_fp32
|
||||
} else if constexpr (is_same<T, bfloat16_t>::value) {
|
||||
LocalTensor<float> x1_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> x2_fp32 = x1x2Local.template ReinterpretCast<float>();
|
||||
|
||||
Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
|
||||
Add(x1_fp32, x1_fp32, x2_fp32, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
// x1+x2 saved in x1_fp32
|
||||
} else {
|
||||
Add(x1Local, x1Local, x2Local, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Adds(xLocal, x1Local, (float)0.0, num);
|
||||
// x1+x2 saved in inQueueX
|
||||
}
|
||||
inQueueX.FreeTensor(x1x2Local);
|
||||
|
||||
// copy out to workspace && x_out
|
||||
outQueueY.EnQue(xLocal);
|
||||
auto x_out = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(xGm[i_idx * numCol + j_idx * ubFactor], x_out, num);
|
||||
outQueueY.FreeTensor(x_out);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeFormer(
|
||||
uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal,
|
||||
LocalTensor<float>& sumLocal, uint32_t num)
|
||||
{
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num);
|
||||
ComputeSum(i_i, sumLocal, num);
|
||||
}
|
||||
BlockReduceSumFP32(sumLocal, sumLocal, calc_row_num * NUM_PER_BLK_FP32);
|
||||
Add(rstdLocal, rstdLocal, sumLocal, calc_row_num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeSum(uint32_t i_i_idx, LocalTensor<float>& sumLocal, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqx, x_fp32, x_fp32, num);
|
||||
} else {
|
||||
LocalTensor<T> xLocal = inQueueX.AllocTensor<float>();
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(sqx, xLocal, xLocal, num);
|
||||
inQueueX.FreeTensor(xLocal);
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(sqx, sqx, avgFactor, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
// 8 means 8 fp32 pre block
|
||||
ReduceSumFP32ToBlock(sumLocal[i_i_idx * 8], sqx, reduce_buf_local, num);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeRstd(LocalTensor<float> rstdLocal, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> reduce_buf_local = reduceFp32Buf.Get<float>();
|
||||
Adds(rstdLocal, rstdLocal, epsilon, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Sqrt(rstdLocal, rstdLocal, num);
|
||||
Duplicate(reduce_buf_local, ONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Div(rstdLocal, reduce_buf_local, rstdLocal, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeLatter(
|
||||
uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, LocalTensor<float>& rstdLocal, uint32_t num)
|
||||
{
|
||||
CopyInGammaBeta(j_idx, num);
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.DeQue<T>();
|
||||
LocalTensor<T> betaLocal;
|
||||
if (!this->nullptrBeta) {
|
||||
betaLocal = inQueueBeta.DeQue<T>();
|
||||
}
|
||||
for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) {
|
||||
CopyInX(i_o_idx * rowFactor + i_i, j_idx, num);
|
||||
ComputeY(i_i, gammaLocal, betaLocal, rstdLocal, num);
|
||||
CopyOutY(i_o_idx * rowFactor + i_i, j_idx, num);
|
||||
}
|
||||
inQueueGamma.FreeTensor(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
inQueueBeta.FreeTensor(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInGammaBeta(uint32_t j_idx, uint32_t num)
|
||||
{
|
||||
LocalTensor<T> gammaLocal = inQueueGamma.AllocTensor<T>();
|
||||
DataCopyCustom<T>(gammaLocal, gammaGm[j_idx * ubFactor], num);
|
||||
inQueueGamma.EnQue(gammaLocal);
|
||||
if (!this->nullptrBeta) {
|
||||
LocalTensor<T> betaLocal = inQueueBeta.AllocTensor<T>();
|
||||
DataCopyCustom<T>(betaLocal, betaGm[j_idx * ubFactor], num);
|
||||
inQueueBeta.EnQue(betaLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInX(uint32_t i_idx, uint32_t j_idx, uint32_t num)
|
||||
{
|
||||
LocalTensor<T> xLocal = inQueueX.AllocTensor<T>();
|
||||
DataCopyCustom<T>(xLocal, xGm[i_idx * numCol + j_idx * ubFactor], num);
|
||||
inQueueX.EnQue<T>(xLocal);
|
||||
if constexpr (is_same<T, half>::value || is_same<T, bfloat16_t>::value) {
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<T> xLocal = inQueueX.DeQue<T>();
|
||||
Cast(x_fp32, xLocal, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
inQueueX.FreeTensor(xLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeY(
|
||||
uint32_t i_i_idx, LocalTensor<half>& gammaLocal, LocalTensor<half>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = rstdLocal.GetValue(i_i_idx);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(x_fp32, x_fp32, rstdValue, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<half> yLocal = outQueueY.AllocTensor<half>();
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(yLocal, gammaLocal, yLocal, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (!this->nullptrBeta) {
|
||||
Add(yLocal, betaLocal, yLocal, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
outQueueY.EnQue<half>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeY(
|
||||
uint32_t i_i_idx, LocalTensor<float>& gammaLocal, LocalTensor<float>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> xLocal = inQueueX.DeQue<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = rstdLocal.GetValue(i_i_idx);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
LocalTensor<float> yLocal = outQueueY.AllocTensor<float>();
|
||||
Muls(yLocal, xLocal, rstdValue, num);
|
||||
inQueueX.FreeTensor(xLocal);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(yLocal, gammaLocal, yLocal, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (!this->nullptrBeta) {
|
||||
Add(yLocal, betaLocal, yLocal, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
outQueueY.EnQue<float>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeY(
|
||||
uint32_t i_i_idx, LocalTensor<bfloat16_t>& gammaLocal, LocalTensor<bfloat16_t>& betaLocal, LocalTensor<float>& rstdLocal, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> x_fp32 = xFp32Buf.Get<float>();
|
||||
LocalTensor<float> sqx = sqxBuf.Get<float>();
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
float rstdValue = rstdLocal.GetValue(i_i_idx);
|
||||
event_t event_s_v = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
|
||||
SetFlag<HardEvent::S_V>(event_s_v);
|
||||
WaitFlag<HardEvent::S_V>(event_s_v);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Muls(x_fp32, x_fp32, rstdValue, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
LocalTensor<bfloat16_t> yLocal = outQueueY.AllocTensor<bfloat16_t>();
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(x_fp32, yLocal, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(sqx, gammaLocal, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Mul(x_fp32, x_fp32, sqx, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (!this->nullptrBeta) {
|
||||
Cast(sqx, betaLocal, RoundMode::CAST_NONE, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Add(x_fp32, x_fp32, sqx, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num);
|
||||
PipeBarrier<PIPE_V>();
|
||||
outQueueY.EnQue<bfloat16_t>(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutY(uint32_t i_idx, uint32_t j_idx, uint32_t num)
|
||||
{
|
||||
LocalTensor<T> yLocal = outQueueY.DeQue<T>();
|
||||
DataCopyCustom<T>(yGm[i_idx * numCol + j_idx * ubFactor], yLocal, num);
|
||||
outQueueY.FreeTensor(yLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOutRstd(uint32_t i_o_idx, uint32_t num)
|
||||
{
|
||||
LocalTensor<float> rstdLocal = outQueueRstd.DeQue<float>();
|
||||
#if __CCE_AICORE__ == 220 || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
DataCopyCustom<float>(rstdGm[i_o_idx * rowFactor], rstdLocal, num);
|
||||
#endif
|
||||
outQueueRstd.FreeTensor(rstdLocal);
|
||||
}
|
||||
|
||||
private:
|
||||
TPipe* Ppipe = nullptr;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueX;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueGamma;
|
||||
TQue<QuePosition::VECIN, BUFFER_NUM> inQueueBeta;
|
||||
// create queues for output, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueY;
|
||||
TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueRstd;
|
||||
TBuf<TPosition::VECCALC> xFp32Buf;
|
||||
TBuf<TPosition::VECCALC> sqxBuf;
|
||||
TBuf<TPosition::VECCALC> sumBuf;
|
||||
TBuf<TPosition::VECCALC> reduceFp32Buf;
|
||||
|
||||
GlobalTensor<T> x1Gm;
|
||||
GlobalTensor<T> x2Gm;
|
||||
GlobalTensor<T> gammaGm;
|
||||
GlobalTensor<T> betaGm;
|
||||
GlobalTensor<T> yGm;
|
||||
GlobalTensor<float> rstdGm;
|
||||
GlobalTensor<T> xGm;
|
||||
|
||||
uint32_t numRow;
|
||||
uint32_t numCol;
|
||||
uint32_t blockFactor; // number of calculations rows on each core
|
||||
uint32_t rowFactor;
|
||||
uint32_t ubFactor;
|
||||
float epsilon;
|
||||
float avgFactor;
|
||||
int32_t blockIdx_;
|
||||
uint32_t rowWork = 1;
|
||||
uint32_t nullptrBeta = 0;
|
||||
|
||||
int tempbufNum;
|
||||
};
|
||||
#endif // _ADD_RMS_NORM_BIAS_SPLIT_D_H_
|
||||
179
csrc/add_rms_norm_bias/op_kernel/reduce_common.h
Normal file
179
csrc/add_rms_norm_bias/op_kernel/reduce_common.h
Normal file
@@ -0,0 +1,179 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
/*!
|
||||
* \file reduce_common.h
|
||||
*/
|
||||
#ifndef REDUCE_COMMON_H_RMS_NORM
|
||||
#define REDUCE_COMMON_H_RMS_NORM
|
||||
#include "kernel_operator.h"
|
||||
using namespace AscendC;
|
||||
|
||||
constexpr uint32_t MAX_REP_NUM = 255;
|
||||
constexpr uint32_t ELEM_PER_REP_FP32 = 64;
|
||||
constexpr uint32_t ELEM_PER_BLK_FP32 = 8;
|
||||
constexpr float ZERO = 0;
|
||||
constexpr int32_t HALf_INTERVAL = 2;
|
||||
constexpr int32_t INDEX_TWO = 2;
|
||||
constexpr int32_t INDEX_FOUR = 4;
|
||||
constexpr int32_t INDEX_EIGHT = 8;
|
||||
constexpr int32_t INDEX_SIXTEEN = 16;
|
||||
|
||||
__aicore__ inline void ReduceSumForSmallReduceDimPreRepeat(
|
||||
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
|
||||
const uint32_t elemNum, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
|
||||
const uint8_t repStride)
|
||||
{
|
||||
uint32_t elemIndex = 0;
|
||||
for (; elemIndex + ELEM_PER_REP_FP32 <= numLastDim; elemIndex += ELEM_PER_REP_FP32) {
|
||||
Add(tmpLocal, srcLocal[elemIndex], tmpLocal, elemNum, repeat,
|
||||
{1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (unlikely(tailCount != 0)) {
|
||||
Add(tmpLocal, srcLocal[elemIndex], tmpLocal, tailCount, repeat,
|
||||
{1, 1, 1, ELEM_PER_BLK_FP32, repStride, ELEM_PER_BLK_FP32});
|
||||
}
|
||||
PipeBarrier<PIPE_V>();
|
||||
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32); // set mask = 64
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
if ASCEND_IS_AIV {
|
||||
WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
|
||||
}
|
||||
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
|
||||
WholeReduceSum(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
|
||||
#else
|
||||
WholeReduceSum<float, false>(dstLocal, tmpLocal, elemNum, repeat, 1, 1, ELEM_PER_BLK_FP32);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* reduce dim form (N, D) to (N, 1)
|
||||
* this reduce sum is for small reduce dim.
|
||||
*/
|
||||
__aicore__ inline void ReduceSumForSmallReduceDim(
|
||||
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
|
||||
const uint32_t numLastDimAligned, const uint32_t numLastDim, const uint32_t tailCount, const uint32_t repeat,
|
||||
const uint8_t repStride)
|
||||
{
|
||||
uint32_t repeatTimes = repeat / MAX_REP_NUM;
|
||||
if (repeatTimes == 0) {
|
||||
ReduceSumForSmallReduceDimPreRepeat(
|
||||
dstLocal, srcLocal, tmpLocal, ELEM_PER_REP_FP32, numLastDim, tailCount, repeat, repStride);
|
||||
} else {
|
||||
uint32_t repTailNum = repeat % MAX_REP_NUM;
|
||||
uint32_t repIndex = 0;
|
||||
uint32_t repElem;
|
||||
for (; repIndex + MAX_REP_NUM <= repeat; repIndex += MAX_REP_NUM) {
|
||||
ReduceSumForSmallReduceDimPreRepeat(
|
||||
dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
|
||||
ELEM_PER_REP_FP32, numLastDim, tailCount, MAX_REP_NUM, repStride);
|
||||
}
|
||||
if (repTailNum != 0) {
|
||||
ReduceSumForSmallReduceDimPreRepeat(
|
||||
dstLocal[repIndex], srcLocal[repIndex * numLastDimAligned], tmpLocal[repIndex * ELEM_PER_REP_FP32],
|
||||
ELEM_PER_REP_FP32, numLastDim, tailCount, repTailNum, repStride);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* reduce dim form (N, D) to (N, 1)
|
||||
* this reduce sum is for small reduce dim, require D < 255 * 8.
|
||||
* size of tmpLocal: (N, 64)
|
||||
*/
|
||||
__aicore__ inline void ReduceSumMultiN(
|
||||
const LocalTensor<float>& dstLocal, const LocalTensor<float>& srcLocal, const LocalTensor<float>& tmpLocal,
|
||||
const uint32_t numRow, const uint32_t numCol, const uint32_t numColAlign)
|
||||
{
|
||||
const uint32_t tailCount = numCol % ELEM_PER_REP_FP32;
|
||||
const uint32_t repeat = numRow;
|
||||
const uint8_t repStride = numColAlign / ELEM_PER_BLK_FP32;
|
||||
Duplicate(tmpLocal, ZERO, numRow * ELEM_PER_REP_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
ReduceSumForSmallReduceDim(dstLocal, srcLocal, tmpLocal, numColAlign, numCol, tailCount, repeat, repStride);
|
||||
}
|
||||
|
||||
__aicore__ inline int32_t findPowerTwo(int32_t n)
|
||||
{
|
||||
// find max power of 2 no more than n (32 bit)
|
||||
n |= n >> 1; // Set the first digit of n's binary to 1
|
||||
n |= n >> INDEX_TWO;
|
||||
n |= n >> INDEX_FOUR;
|
||||
n |= n >> INDEX_EIGHT;
|
||||
n |= n >> INDEX_SIXTEEN;
|
||||
return (n + 1) >> 1;
|
||||
}
|
||||
|
||||
__aicore__ inline void ReduceSumHalfInterval(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
|
||||
{
|
||||
if (likely(count > ELEM_PER_REP_FP32)) {
|
||||
int32_t bodyCount = findPowerTwo(count);
|
||||
int32_t tailCount = count - bodyCount;
|
||||
if (tailCount > 0) {
|
||||
Add(src_local, src_local, src_local[bodyCount], tailCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
while (bodyCount > ELEM_PER_REP_FP32) {
|
||||
bodyCount = bodyCount / HALf_INTERVAL;
|
||||
Add(src_local, src_local, src_local[bodyCount], bodyCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
|
||||
} else {
|
||||
AscendCUtils::SetMask<float>(count);
|
||||
}
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
if (g_coreType == AIV) {
|
||||
WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
|
||||
}
|
||||
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
|
||||
WholeReduceSum(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
|
||||
#else
|
||||
WholeReduceSum<float, false>(dst_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
#endif
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline float ReduceSumHalfInterval(const LocalTensor<float>& src_local, int32_t count)
|
||||
{
|
||||
if (likely(count > ELEM_PER_REP_FP32)) {
|
||||
int32_t bodyCount = findPowerTwo(count);
|
||||
int32_t tailCount = count - bodyCount;
|
||||
if (tailCount > 0) {
|
||||
Add(src_local, src_local, src_local[bodyCount], tailCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
while (bodyCount > ELEM_PER_REP_FP32) {
|
||||
bodyCount = bodyCount / HALf_INTERVAL;
|
||||
Add(src_local, src_local, src_local[bodyCount], bodyCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
AscendCUtils::SetMask<float>(ELEM_PER_REP_FP32);
|
||||
} else {
|
||||
AscendCUtils::SetMask<float>(count);
|
||||
}
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
if (g_coreType == AIV) {
|
||||
WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 0, 1, 0);
|
||||
}
|
||||
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003
|
||||
WholeReduceSum(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, ELEM_PER_BLK_FP32);
|
||||
#else
|
||||
WholeReduceSum<float, false>(src_local, src_local, ELEM_PER_REP_FP32, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
#endif
|
||||
event_t event_v_s = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
|
||||
SetFlag<HardEvent::V_S>(event_v_s);
|
||||
WaitFlag<HardEvent::V_S>(event_v_s);
|
||||
return src_local.GetValue(0);
|
||||
}
|
||||
#endif // _REDUCE_COMMON_H_
|
||||
316
csrc/add_rms_norm_bias/op_kernel/rms_norm_base.h
Normal file
316
csrc/add_rms_norm_bias/op_kernel/rms_norm_base.h
Normal file
@@ -0,0 +1,316 @@
|
||||
/**
|
||||
* This program is free software, you can redistribute it and/or modify.
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef RMS_NORM_BASE_H_
|
||||
#define RMS_NORM_BASE_H_
|
||||
#include "kernel_operator.h"
|
||||
#include "reduce_common.h"
|
||||
|
||||
namespace RmsNorm {
|
||||
using namespace AscendC;
|
||||
|
||||
|
||||
/**
|
||||
* Get the block size of unified buffer in bytes
|
||||
*/
|
||||
__aicore__ inline constexpr uint32_t GetUbBlockSize()
|
||||
{
|
||||
return 32U;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of vector registers in bytes
|
||||
*/
|
||||
__aicore__ inline constexpr uint32_t GetVRegSize()
|
||||
{
|
||||
#if __CCE_AICORE__ == 310
|
||||
return AscendC::VECTOR_REG_WIDTH;
|
||||
#else
|
||||
return 256U;
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ != 220 && __CCE_AICORE__ != 310
|
||||
#define bfloat16_t int16_t
|
||||
#endif
|
||||
constexpr int32_t BUFFER_NUM = 1; // tensor num for each queue
|
||||
constexpr int32_t DOUBLE_BUFFER_NUM = 2;
|
||||
constexpr int32_t UNROLL_NUM = 2;
|
||||
constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
|
||||
constexpr int32_t NUM_PER_BLK_FP32 = 8;
|
||||
constexpr int32_t FLOAT_BTYPE_SIZE = 4;
|
||||
constexpr int32_t NUM_PER_BLK_FP16 = 16;
|
||||
constexpr int32_t CONTINUE_STRIDE = 8;
|
||||
constexpr int32_t BLOCK_SIZE = 32;
|
||||
constexpr uint32_t ONCE_VECTOR_SIZE = 256;
|
||||
constexpr float MINUS_HALF = -0.5f;
|
||||
constexpr uint32_t ZERO_UINT = 0;
|
||||
constexpr uint32_t ONE_UINT = 1;
|
||||
constexpr uint32_t TWO_UINT = 2;
|
||||
constexpr uint32_t THREE_UINT = 3;
|
||||
constexpr float ONE = 1;
|
||||
constexpr int32_t SECOND_LOOP = 2;
|
||||
constexpr int32_t HALf_INTERVAL = 2;
|
||||
constexpr int32_t MAX_REAPEAT = 255;
|
||||
constexpr int32_t DIM_NUM = 2;
|
||||
constexpr int32_t NDDMA_DIM = 5;
|
||||
|
||||
constexpr uint32_t V_LENGTH = GetVRegSize() / sizeof(float);
|
||||
constexpr uint64_t ALIGN_512_FACTOR = 512;
|
||||
constexpr uint64_t ALIGN_32_FACTOR = 32;
|
||||
constexpr int32_t CONST_FACTOR_2 = 2;
|
||||
constexpr uint32_t SUM_COUNT = 2;
|
||||
constexpr int32_t MOV_2 = 2;
|
||||
constexpr int32_t MOV_4 = 4;
|
||||
constexpr int32_t MOV_8 = 8;
|
||||
constexpr int32_t MOV_16 = 16;
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T CeilDiv(T x, T y)
|
||||
{
|
||||
return y == 0 ? x : (x + y - 1) / y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Min(T left, T right)
|
||||
{
|
||||
return (left < right ? left : right);
|
||||
}
|
||||
|
||||
template <typename Tp, Tp v>
|
||||
struct integral_constant {
|
||||
static constexpr Tp value = v;
|
||||
};
|
||||
using true_type = integral_constant<bool, true>;
|
||||
using false_type = integral_constant<bool, false>;
|
||||
template <typename, typename>
|
||||
struct is_same : public false_type {};
|
||||
template <typename Tp>
|
||||
struct is_same<Tp, Tp> : public true_type {};
|
||||
|
||||
template <typename T, typename T_GAMMA>
|
||||
class KernelRmsNormBase {
|
||||
#define IS_X_FP32 (is_same<T, float>::value)
|
||||
#define IS_GAMMA_FP32 (is_same<T_GAMMA, float>::value)
|
||||
#define IS_MIX_DTYPE ((!IS_X_FP32) && IS_GAMMA_FP32)
|
||||
};
|
||||
|
||||
__aicore__ inline int32_t findPowerTwo(int32_t n)
|
||||
{
|
||||
// find max power of 2 no more than n (32 bit)
|
||||
n |= n >> 1; // Set the first digit of n's binary to 1
|
||||
n |= n >> MOV_2;
|
||||
n |= n >> MOV_4;
|
||||
n |= n >> MOV_8;
|
||||
n |= n >> MOV_16;
|
||||
return (n + 1) >> 1;
|
||||
}
|
||||
|
||||
__aicore__ inline void ReduceSumHalfIntervalToRepeat(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count, int32_t left)
|
||||
{
|
||||
// count need smaller than 255 repeat
|
||||
if (likely(count > NUM_PER_BLK_FP32)) {
|
||||
int32_t bodyCount = count - left;
|
||||
int32_t tailCount = left;
|
||||
if (tailCount > 0) {
|
||||
Add(src_local, src_local, src_local[bodyCount], tailCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
while (bodyCount > SECOND_LOOP * NUM_PER_BLK_FP32) {
|
||||
bodyCount = bodyCount / HALf_INTERVAL;
|
||||
Add(src_local, src_local, src_local[bodyCount], bodyCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
bodyCount = bodyCount / HALf_INTERVAL;
|
||||
Add(dst_local, src_local, src_local[bodyCount], bodyCount);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ReduceSumFP32(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
|
||||
int32_t count)
|
||||
{
|
||||
// count need smaller than 255 repeat
|
||||
uint64_t mask = NUM_PER_REP_FP32;
|
||||
int32_t repeatTimes = count / NUM_PER_REP_FP32;
|
||||
int32_t tailCount = count % NUM_PER_REP_FP32;
|
||||
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
|
||||
BinaryRepeatParams repeatParams;
|
||||
repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
|
||||
repeatParams.src0BlkStride = 1;
|
||||
repeatParams.src1RepStride = 0;
|
||||
repeatParams.src1BlkStride = 1;
|
||||
repeatParams.dstRepStride = 0;
|
||||
repeatParams.dstBlkStride = 1;
|
||||
Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (likely(repeatTimes > 0)) {
|
||||
Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (unlikely(tailCount != 0)) {
|
||||
Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
if (g_coreType == AIV) {
|
||||
WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 0, 1, 0);
|
||||
}
|
||||
#elif !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
WholeReduceSum<float, false>(dst_local, work_local, MASK_PLACEHOLDER, 1, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
#endif
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ReduceSumCustom(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
|
||||
int32_t count)
|
||||
{
|
||||
ReduceSumFP32(dst_local, src_local, work_local, count);
|
||||
}
|
||||
__aicore__ inline void ReduceSumFP32ToBlock(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, const LocalTensor<float>& work_local,
|
||||
int32_t count)
|
||||
{
|
||||
// count need smaller than 255 repeat
|
||||
uint64_t mask = NUM_PER_REP_FP32;
|
||||
int32_t repeatTimes = count / NUM_PER_REP_FP32;
|
||||
int32_t tailCount = count % NUM_PER_REP_FP32;
|
||||
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
|
||||
BinaryRepeatParams repeatParams;
|
||||
repeatParams.src0RepStride = ONCE_VECTOR_SIZE / BLOCK_SIZE;
|
||||
repeatParams.src0BlkStride = 1;
|
||||
repeatParams.src1RepStride = 0;
|
||||
repeatParams.src1BlkStride = 1;
|
||||
repeatParams.dstRepStride = 0;
|
||||
repeatParams.dstBlkStride = 1;
|
||||
Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
|
||||
PipeBarrier<PIPE_V>();
|
||||
if (likely(repeatTimes > 0)) {
|
||||
Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (unlikely(tailCount != 0)) {
|
||||
Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void BlockReduceSumFP32(
|
||||
const LocalTensor<float>& dst_local, const LocalTensor<float>& src_local, int32_t count)
|
||||
{
|
||||
// count need multiple of 8
|
||||
int32_t repeatTimes = count / NUM_PER_REP_FP32;
|
||||
int32_t tailCount = count % NUM_PER_REP_FP32;
|
||||
int32_t dstAddr = repeatTimes * 8;
|
||||
int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32;
|
||||
if (likely(repeatTimes > 0)) {
|
||||
BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
if (tailCount != 0) {
|
||||
BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, DEFAULT_REPEAT_STRIDE);
|
||||
PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename R>
|
||||
__aicore__ inline void DataCopyCustom(const U& dstTensor, const R& srcTensor, const uint32_t count)
|
||||
{
|
||||
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3003)
|
||||
DataCopyParams copyParams;
|
||||
copyParams.blockLen = count * sizeof(T);
|
||||
copyParams.blockCount = 1;
|
||||
if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
|
||||
DataCopyPadParams padParams;
|
||||
DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
|
||||
} else {
|
||||
DataCopyPad(dstTensor, srcTensor, copyParams);
|
||||
}
|
||||
#else
|
||||
// only support count greater than 32byte
|
||||
int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T);
|
||||
if (count % numPerBlock == 0) {
|
||||
DataCopy(dstTensor, srcTensor, count);
|
||||
} else {
|
||||
if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
|
||||
int32_t num = AlignUp(count, numPerBlock);
|
||||
DataCopy(dstTensor, srcTensor, num);
|
||||
} else {
|
||||
if (count < numPerBlock) {
|
||||
DataCopy(dstTensor, srcTensor, numPerBlock);
|
||||
} else {
|
||||
int32_t num = count / numPerBlock * numPerBlock;
|
||||
DataCopy(dstTensor, srcTensor, num);
|
||||
SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
|
||||
for (int32_t i = 0; i < numPerBlock; i++) {
|
||||
T tensorValue = srcTensor.GetValue(count - numPerBlock + i);
|
||||
srcTensor.SetValue(i, tensorValue);
|
||||
}
|
||||
SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
WaitFlag<HardEvent::S_MTE3>(EVENT_ID0);
|
||||
DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void DataCopyCustom(
|
||||
const LocalTensor<T>& dstTensor, const GlobalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
|
||||
{
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
DataCopyParams copyParams;
|
||||
copyParams.blockLen = numCol * sizeof(T);
|
||||
copyParams.blockCount = numRow;
|
||||
DataCopyPadParams padParams;
|
||||
DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void DataCopyCustom(
|
||||
const GlobalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, const uint32_t numRow, const uint32_t numCol)
|
||||
{
|
||||
#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 220
|
||||
DataCopyParams copyParams;
|
||||
copyParams.blockLen = numCol * sizeof(T);
|
||||
copyParams.blockCount = numRow;
|
||||
DataCopyPad(dstTensor, srcTensor, copyParams);
|
||||
#endif
|
||||
}
|
||||
|
||||
__aicore__ inline void RoundFloat2Int8(LocalTensor<int8_t>& dstTensor, LocalTensor<float>& srcTensor, int32_t size)
|
||||
{
|
||||
Cast(srcTensor.ReinterpretCast<int32_t>(), srcTensor, RoundMode::CAST_RINT, size);
|
||||
PipeBarrier<PIPE_V>();
|
||||
SetDeqScale((half)1.000000e+00f);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(srcTensor.ReinterpretCast<half>(), srcTensor.ReinterpretCast<int32_t>(), RoundMode::CAST_NONE, size);
|
||||
PipeBarrier<PIPE_V>();
|
||||
Cast(dstTensor, srcTensor.ReinterpretCast<half>(), RoundMode::CAST_TRUNC, size);
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t ROUND_UP(uint32_t x, uint32_t block_number)
|
||||
{
|
||||
if (block_number > 0) {
|
||||
return (x + block_number - 1) / block_number * block_number;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace RmsNorm
|
||||
#endif // RMS_NORM_BASE_H_
|
||||
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -79,6 +79,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
"notify_dispatch"
|
||||
"moe_init_routing_custom"
|
||||
"moe_gating_top_k"
|
||||
"add_rms_norm_bias"
|
||||
)
|
||||
CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}")
|
||||
SOC_ARG="ascend910_93"
|
||||
|
||||
@@ -1288,6 +1288,38 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias(
|
||||
const at::Tensor& x1,
|
||||
const at::Tensor& x2,
|
||||
const at::Tensor& gamma,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
double epsilon)
|
||||
{
|
||||
int64_t dim_x = x1.dim();
|
||||
int64_t dim_gamma = gamma.dim();
|
||||
int64_t diff = dim_x - dim_gamma;
|
||||
std::vector<int64_t> new_shape;
|
||||
at::Tensor rstd;
|
||||
|
||||
if (diff > 0) {
|
||||
new_shape.reserve(dim_x);
|
||||
auto x1_sizes = x1.sizes();
|
||||
for (int64_t i = 0; i < diff; ++i) {
|
||||
new_shape.push_back(x1_sizes[i]);
|
||||
}
|
||||
for (int64_t i = 0; i < dim_gamma; ++i) {
|
||||
new_shape.push_back(1);
|
||||
}
|
||||
} else {
|
||||
new_shape.assign(dim_x, 1);
|
||||
}
|
||||
rstd = at::empty(new_shape, x1.options().dtype(at::kFloat));
|
||||
at::Tensor y = at::empty(x1.sizes(), x1.options());
|
||||
at::Tensor x = at::empty(x1.sizes(), x1.options());
|
||||
EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x);
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -1453,4 +1485,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"-> (Tensor y ,Tensor expert_idx, Tensor out)"
|
||||
);
|
||||
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
|
||||
|
||||
ops.def(
|
||||
"npu_add_rms_norm_bias(Tensor x1, "
|
||||
"Tensor x2, "
|
||||
"Tensor gamma, "
|
||||
"Tensor? beta=None, "
|
||||
"float epsilon=1e-6)"
|
||||
"-> (Tensor y ,Tensor rstd, Tensor x)"
|
||||
);
|
||||
ops.impl("npu_add_rms_norm_bias", torch::kPrivateUse1, &vllm_ascend::npu_add_rms_norm_bias);
|
||||
}
|
||||
|
||||
@@ -403,6 +403,37 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
|
||||
const at::Tensor& x1,
|
||||
const at::Tensor& x2,
|
||||
const at::Tensor& gamma,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
double epsilon)
|
||||
{
|
||||
int64_t dim_x = x1.dim();
|
||||
int64_t dim_gamma = gamma.dim();
|
||||
int64_t diff = dim_x - dim_gamma;
|
||||
c10::SymDimVector new_shape;
|
||||
at::Tensor rstd;
|
||||
|
||||
if (diff > 0) {
|
||||
new_shape.reserve(dim_x);
|
||||
auto x1_sizes = x1.sym_sizes();
|
||||
for (int64_t i = 0; i < diff; ++i) {
|
||||
new_shape.push_back(x1_sizes[i]);
|
||||
}
|
||||
for (int64_t i = 0; i < dim_gamma; ++i) {
|
||||
new_shape.push_back(c10::SymInt(1));
|
||||
}
|
||||
} else {
|
||||
new_shape.assign(dim_x, c10::SymInt(1));
|
||||
}
|
||||
rstd = at::empty_symint(new_shape, x1.options().dtype(at::kFloat));
|
||||
at::Tensor y = at::empty_symint(x1.sym_sizes(), x1.options());
|
||||
at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options());
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||
}
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -441,5 +472,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta);
|
||||
// Moe_gating_top_k
|
||||
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
|
||||
// Add_Rms_Norm_Bias
|
||||
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,9 @@ def check_outputs_equal(
|
||||
# The text and token outputs should exactly match
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
f"\n{name_1}:\t{output_str_1!r}"
|
||||
f"\n{name_0}:\t{output_ids_0!r}"
|
||||
f"\n{name_1}:\t{output_ids_1!r}")
|
||||
|
||||
assert output_str_0 == output_str_1, fail_msg
|
||||
assert output_ids_0 == output_ids_1, fail_msg
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
enable_custom_op()
|
||||
seed = 45
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def npu_add_rms_norm_bias_golden(input_x1,
|
||||
input_x2,
|
||||
input_gamma,
|
||||
input_beta,
|
||||
kernelType,
|
||||
epsilon=0.000001):
|
||||
ori_x_shape = input_x1.shape
|
||||
ori_gamma_shape = input_gamma.shape
|
||||
xlength = len(ori_x_shape)
|
||||
gammaLength = len(ori_gamma_shape)
|
||||
torchType32 = torch.float32
|
||||
rstdShape = []
|
||||
rstdSize = 1
|
||||
for i in range(xlength):
|
||||
if i < (xlength - gammaLength):
|
||||
rstdShape.append(ori_x_shape[i])
|
||||
rstdSize = rstdSize * ori_x_shape[i]
|
||||
else:
|
||||
rstdShape.append(1)
|
||||
|
||||
n = xlength - gammaLength
|
||||
gammaSize = np.multiply.reduce(np.array(ori_gamma_shape))
|
||||
input_gamma = input_gamma.reshape(gammaSize)
|
||||
input_beta = input_beta.reshape(gammaSize)
|
||||
x1_shape = ori_x_shape[0:n] + input_gamma.shape
|
||||
input_x1 = input_x1.reshape(x1_shape)
|
||||
input_x2 = input_x2.reshape(x1_shape)
|
||||
|
||||
if kernelType == 1:
|
||||
oriType = torch.float16
|
||||
xOut = (input_x1.to(oriType) + input_x2.to(oriType))
|
||||
elif kernelType == 2:
|
||||
oriType = torch.bfloat16
|
||||
x_fp32 = (input_x1.to(torchType32) + input_x2.to(torchType32))
|
||||
xOut = x_fp32.to(oriType)
|
||||
else:
|
||||
oriType = torch.float32
|
||||
xOut = (input_x1.to(torchType32) + input_x2.to(torchType32))
|
||||
x_fp32 = xOut.to(torchType32)
|
||||
avgFactor = 1 / gammaSize
|
||||
x_2 = torch.pow(x_fp32, 2)
|
||||
x_2_mean = x_2 * avgFactor
|
||||
tmp_sum = torch.sum(x_2_mean, axis=-1, keepdims=True)
|
||||
tmp_add_eps = tmp_sum + epsilon
|
||||
std = torch.sqrt(tmp_add_eps)
|
||||
rstd = 1 / std
|
||||
result_mid = x_fp32 * rstd
|
||||
if kernelType == 1:
|
||||
result_mid_ori = result_mid.to(oriType)
|
||||
y_array = result_mid_ori * input_gamma.to(oriType)
|
||||
y_array = y_array + input_beta.to(oriType)
|
||||
elif kernelType == 2:
|
||||
result_mid_ori = result_mid.to(oriType)
|
||||
y_array = result_mid_ori.to(torchType32) * input_gamma.to(torchType32)
|
||||
y_array = y_array + input_beta.to(torchType32)
|
||||
else:
|
||||
y_array = result_mid.to(torchType32) * input_gamma.to(torchType32)
|
||||
y_array = y_array + input_beta.to(torchType32)
|
||||
rstdOut = rstd.reshape(rstdShape).to(torchType32)
|
||||
yOut = y_array.reshape(ori_x_shape).to(oriType)
|
||||
xOut = x_fp32.reshape(ori_x_shape).to(oriType)
|
||||
return yOut, rstdOut, xOut
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'row',
|
||||
[1, 16, 64, 77, 128, 255, 1000],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
'col',
|
||||
[
|
||||
8,
|
||||
16,
|
||||
128,
|
||||
3000,
|
||||
7168,
|
||||
15000,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"dtype, atol, rtol, kernelType",
|
||||
[
|
||||
(torch.float16, 0.0010986328125, 0.0010986328125, 1),
|
||||
(torch.bfloat16, 0.0079345703125, 0.0079345703125, 2),
|
||||
(torch.float32, 0.000244140625, 0.000244140625, 3),
|
||||
],
|
||||
)
|
||||
def test_quant_fpx_linear(row: int, col: int, dtype, atol, rtol, kernelType):
|
||||
shape_x = [row, col]
|
||||
shape_gamma = [col]
|
||||
|
||||
dataType = dtype
|
||||
|
||||
input_x1 = np.random.uniform(1, 10, size=tuple(shape_x)).astype(np.float32)
|
||||
input_x1_tensor = torch.tensor(input_x1).type(dataType)
|
||||
|
||||
input_x2 = np.random.uniform(1, 10, size=tuple(shape_x)).astype(np.float32)
|
||||
input_x2_tensor = torch.tensor(input_x2).type(dataType)
|
||||
|
||||
input_gamma = np.random.uniform(1, 10,
|
||||
size=tuple(shape_gamma)).astype(np.float32)
|
||||
input_gamma_tensor = torch.tensor(input_gamma).type(dataType)
|
||||
|
||||
input_beta = np.random.uniform(1, 10,
|
||||
size=tuple(shape_gamma)).astype(np.float32)
|
||||
grad_bias = torch.tensor(input_beta).type(dataType)
|
||||
y, rstd, x = torch.ops._C_ascend.npu_add_rms_norm_bias(input_x1_tensor.npu(),
|
||||
input_x2_tensor.npu(),
|
||||
input_gamma_tensor.npu(),
|
||||
grad_bias.npu(), 1e-6)
|
||||
|
||||
y = y.cpu()
|
||||
rstd = rstd.cpu()
|
||||
x = x.cpu()
|
||||
|
||||
y1, rstd1, x1 = npu_add_rms_norm_bias_golden(input_x1_tensor,
|
||||
input_x2_tensor,
|
||||
input_gamma_tensor,
|
||||
grad_bias,
|
||||
kernelType,
|
||||
epsilon=0.000001)
|
||||
|
||||
a = y1 > 1
|
||||
a1 = y1 <= 1
|
||||
b = rstd1 > 1
|
||||
b1 = rstd1 <= 1
|
||||
c = x1 > 1
|
||||
c1 = x1 <= 1
|
||||
torch.testing.assert_close(y * a, y1 * a, atol=atol, rtol=100)
|
||||
torch.testing.assert_close(y * a1, y1 * a1, rtol=rtol, atol=100)
|
||||
torch.testing.assert_close(rstd * b, rstd1 * b, atol=atol, rtol=100)
|
||||
torch.testing.assert_close(rstd * b1, rstd1 * b1, rtol=rtol, atol=100)
|
||||
torch.testing.assert_close(x * c, x1 * c, atol=atol, rtol=100)
|
||||
torch.testing.assert_close(x * c1, x1 * c1, rtol=rtol, atol=100)
|
||||
@@ -420,7 +420,7 @@ def test_llama_qwen_eagle_acceptance(
|
||||
]
|
||||
golden = BASELINES[method]
|
||||
|
||||
match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
|
||||
match = all(abs(a - b) < 0.08 for a, b in zip(acceptance_per_pos, golden))
|
||||
if not match:
|
||||
print(f"acceptance_per_pos: {acceptance_per_pos}")
|
||||
print(f"golden: {golden}")
|
||||
|
||||
@@ -57,9 +57,9 @@ CASE_DS_FULL_DECODE_ONLY = LLMTestCase(
|
||||
quantization="ascend",
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
'\n\nSelect an assignment template',
|
||||
'\n\nSelect an assignment template',
|
||||
'\n\nSelect an assignment template'
|
||||
"\n\nSelect an assignment template",
|
||||
"\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use",
|
||||
"\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations"
|
||||
])
|
||||
|
||||
CASE_QWEN_EX = LLMTestCase(
|
||||
@@ -75,9 +75,9 @@ CASE_DS_EX = LLMTestCase(model="vllm-ascend/DeepSeek-V2-Lite-W8A8",
|
||||
quantization="ascend",
|
||||
prompts=PROMPTS_LONG,
|
||||
golden_answers=[
|
||||
'\n\nYour answer seems reasonable. Find out if you\'re right!\n\nSign up to access problem solutions.\n\nThat seems reasonable. Find out',
|
||||
'\n\nI\'m not sure how to approach this problem. I\'m not sure if I should use the law of total probability or if I should use',
|
||||
'\n\nLet $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$and $'
|
||||
"\n\nSelect an assignment template",
|
||||
"\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use",
|
||||
"\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations"
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("cur_case", [CASE_QWEN_ACLGRAPH, CASE_DS_ACLGRAPH])
|
||||
|
||||
@@ -28,8 +28,8 @@ def test_qwen3_w8a8_quant():
|
||||
]
|
||||
vllm_target_outputs = [([
|
||||
85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
|
||||
13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
|
||||
], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
|
||||
13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 369, 3460
|
||||
], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed for large'
|
||||
)]
|
||||
|
||||
with VllmRunner(
|
||||
|
||||
@@ -6,6 +6,8 @@ from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
enable_custom_op()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,6 +22,13 @@ def mock_rms_norm(x, weight, eps):
|
||||
def mock_add_rms_norm(x, residual, weight, eps):
|
||||
return 2 * x, None, 2 * residual
|
||||
|
||||
def mock_add_rms_norm_bias(x, residual, weight, bias, eps):
|
||||
if bias is None:
|
||||
return 2 * x, None, 2 * residual
|
||||
else:
|
||||
return 2 * x + bias, None, 2 * residual
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def default_vllm_config():
|
||||
@@ -35,7 +44,8 @@ def default_vllm_config():
|
||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
|
||||
@patch("torch.ops._C_ascend.npu_add_rms_norm_bias", side_effect=mock_add_rms_norm_bias)
|
||||
def test_RMSNorm_forward(mock_add_rms_norm_bias, mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
|
||||
dummy_tensor, default_vllm_config):
|
||||
|
||||
with patch("vllm_ascend.utils.get_ascend_device_type",
|
||||
@@ -56,7 +66,7 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
|
||||
else:
|
||||
expected_out_x = 2 * dummy_tensor
|
||||
expected_out_residual = 2 * residual
|
||||
mock_add_rmsnorm.assert_called_once()
|
||||
mock_add_rms_norm_bias.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
else:
|
||||
|
||||
@@ -23,6 +23,9 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated
|
||||
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
|
||||
|
||||
from vllm_ascend.utils import enable_custom_op
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
|
||||
def __init__(
|
||||
@@ -57,6 +60,9 @@ class AscendRMSNorm(RMSNorm):
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
elif enable_custom_op():
|
||||
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
x, residual, self.weight, self.bias, self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
@@ -88,6 +94,10 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
elif enable_custom_op():
|
||||
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||
x, residual, 1.0 + self.weight, None,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, 1.0 + self.weight, self.variance_epsilon)
|
||||
|
||||
Reference in New Issue
Block a user