add mla_preprocess kernel (#3226)
### What this PR does / why we need it? - Adds the `mla_preprocess` custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs. - Wires the new kernel into the C++ extension pipeline so vLLM can invoke it directly, cutting Python-side tensor shuffling and memory copies that previously bottlenecked MLA compilation paths. ### Does this PR introduce any user-facing change? - No. The change only introduces a low-level kernel; public APIs and inference behavior remain unchanged. ### How was this patch tested? - Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path. - vLLM version: v0.11.0 Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
25
csrc/mla_preprocess/op_kernel/kernel/common.h
Normal file
25
csrc/mla_preprocess/op_kernel/kernel/common.h
Normal file
@@ -0,0 +1,25 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_COMMON_H
|
||||
#define INCLUDE_COMMON_H
|
||||
|
||||
#define CONST_2 2
|
||||
|
||||
#define SET_FLAG(trigger, waiter, e) AscendC::SetFlag<AscendC::HardEvent::trigger##_##waiter>((e))
|
||||
#define WAIT_FLAG(trigger, waiter, e) AscendC::WaitFlag<AscendC::HardEvent::trigger##_##waiter>((e))
|
||||
#define PIPE_BARRIER(pipe) AscendC::PipeBarrier<PIPE_##pipe>()
|
||||
|
||||
#ifndef __force_inline__
|
||||
#define __force_inline__ inline __attribute__((always_inline))
|
||||
#endif
|
||||
|
||||
#endif
|
||||
121
csrc/mla_preprocess/op_kernel/kernel/common_func.h
Normal file
121
csrc/mla_preprocess/op_kernel/kernel/common_func.h
Normal file
@@ -0,0 +1,121 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef INCLUDE_COMMON_FUNC_H
|
||||
#define INCLUDE_COMMON_FUNC_H
|
||||
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
#ifdef __CCE_KT_TEST__
|
||||
#include "stub_def.h"
|
||||
#include "stub_fun.h"
|
||||
#else
|
||||
#include "kernel_macros.h"
|
||||
#endif
|
||||
|
||||
template <uint32_t ALIGN, typename T = uint32_t>
|
||||
inline __aicore__ T RoundUp(const T val)
|
||||
{
|
||||
static_assert(ALIGN != 0, "align must not be zero");
|
||||
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
|
||||
T align = ALIGN;
|
||||
if (val + align - 1 < val) {
|
||||
return val;
|
||||
}
|
||||
return (val + align - 1) / align * align;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __aicore__ T RoundUp(const T val, const T align)
|
||||
{
|
||||
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
|
||||
if (align == 0 || val + align - 1 < val) {
|
||||
return val;
|
||||
}
|
||||
return (val + align - 1) / align * align;
|
||||
}
|
||||
|
||||
template <uint32_t DIVISOR, typename T = uint32_t>
|
||||
inline __aicore__ T CeilDiv(const T dividend)
|
||||
{
|
||||
static_assert(DIVISOR != 0, "align must not be zero");
|
||||
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
|
||||
T divisor = DIVISOR;
|
||||
if (dividend + divisor - 1 < dividend) {
|
||||
return dividend;
|
||||
}
|
||||
return (dividend + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr T T_MAX = std::numeric_limits<T>::max();
|
||||
|
||||
template <typename T>
|
||||
inline __aicore__ T CeilDiv(const T dividend, const T divisor)
|
||||
{
|
||||
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
|
||||
if (divisor == 0 || dividend + divisor - 1 < dividend) {
|
||||
return T_MAX<T>;
|
||||
}
|
||||
return (dividend + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T Min(const T lhs, const T rhs)
|
||||
{
|
||||
return lhs < rhs ? lhs : rhs;
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint32_t BlockSize()
|
||||
{
|
||||
return 32 / sizeof(Dtype);
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize()
|
||||
{
|
||||
return 512 / sizeof(Dtype);
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num)
|
||||
{
|
||||
return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>() * BlockSize<Dtype>();
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num)
|
||||
{
|
||||
return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>();
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num)
|
||||
{
|
||||
return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>() * MatrixSize<Dtype>();
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num)
|
||||
{
|
||||
return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>();
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize()
|
||||
{
|
||||
return 32 * 1024 / sizeof(Dtype);
|
||||
}
|
||||
|
||||
#endif
|
||||
36
csrc/mla_preprocess/op_kernel/kernel/hardware.h
Normal file
36
csrc/mla_preprocess/op_kernel/kernel/hardware.h
Normal file
@@ -0,0 +1,36 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_HARDWARE_H
|
||||
#define INCLUDE_HARDWARE_H
|
||||
|
||||
enum class ArchType { ASCEND_V220, ASCEND_V200, ASCEND_M200 };
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct HardwareInfo {
|
||||
static uint32_t const l2BW = 5;
|
||||
static uint32_t const hbmBW = 1;
|
||||
static uint32_t const supportMix = 0;
|
||||
static uint32_t const l1Size = 512 * 1024;
|
||||
static uint32_t const l0ASize = 64 * 1024;
|
||||
static uint32_t const l0BSize = 64 * 1024;
|
||||
static uint32_t const l0CSize = 128 * 1024;
|
||||
static uint32_t const l2Size = 192 * 1024 * 1024;
|
||||
static uint32_t const biasSize = 1024;
|
||||
static uint32_t const fixBufSize = 7 * 1024;
|
||||
static uint32_t const ubSize = 192 * 1024;
|
||||
static uint32_t const fractalSize = 512;
|
||||
static uint32_t const l1l0BlockSize = 32;
|
||||
static uint32_t const btBlockSize = 64;
|
||||
static uint32_t const fbBlockSize = 128;
|
||||
};
|
||||
|
||||
#endif
|
||||
92
csrc/mla_preprocess/op_kernel/kernel/iterator.h
Normal file
92
csrc/mla_preprocess/op_kernel/kernel/iterator.h
Normal file
@@ -0,0 +1,92 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_ITERTOR_H
|
||||
#define INCLUDE_ITERTOR_H
|
||||
|
||||
#include "common_func.h"
|
||||
#include "hardware.h"
|
||||
#include "kernel_operator.h"
|
||||
#include "layout.h"
|
||||
#include "mem.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// gm_to_l1
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType, DataFormat FormatInGM, DataFormat FormatInL1>
|
||||
struct gm_to_l1 {
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor, AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
|
||||
uint32_t dTileCeil, uint32_t dVal) {};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_a
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
|
||||
struct l1_to_l0_a {
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil, uint32_t kPartCeil, uint32_t mSrcStride, uint32_t kSrcStride,
|
||||
uint32_t mDstStride, uint32_t kDstStride) {};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_b
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
|
||||
struct l1_to_l0_b {
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil, uint32_t kPartCeil, uint32_t nSrcStride, uint32_t kSrcStride,
|
||||
uint32_t nDstStride, uint32_t kDstStride) {};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_gm
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, DataFormat OutFormatType, typename OutDataType, typename L0CDataType>
|
||||
struct l0c_to_gm {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<OutDataType> gmTensor, AscendC::LocalTensor<L0CDataType> l0cTensor,
|
||||
uint32_t mTileActual, uint32_t nTileActual, uint32_t mTileCeil, uint32_t nActual,
|
||||
uint8_t unitFlag = 0) {};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_l1
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, DataFormat LayoutOut, typename ElementOut, typename ElementIn>
|
||||
struct l0c_to_l1 {
|
||||
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor, AscendC::LocalTensor<ElementIn> l0cTensor,
|
||||
AscendC::LocalTensor<uint64_t> deqTensor, uint32_t mTileActual, uint32_t nTileActual,
|
||||
uint32_t mTileCeil, uint32_t nActual) {};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_bt {
|
||||
__aicore__ l1_to_bt(uint64_t dst, const AscendC::LocalTensor<DataType> &src, uint16_t convControl, uint16_t nBurst,
|
||||
uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) {};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_fb {
|
||||
__aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst, AscendC::LocalTensor<DataType> &src, uint16_t burstNum,
|
||||
uint16_t burstLen, uint16_t srcGap, uint16_t dstGap) {};
|
||||
};
|
||||
|
||||
#include "iterators/gm_to_l1_iterator.inc"
|
||||
#include "iterators/gm_to_ub_iterator.inc"
|
||||
#include "iterators/l0c_to_gm_iterator.inc"
|
||||
#include "iterators/l0c_to_l1_iterator.inc"
|
||||
#include "iterators/l0c_to_ub_iterator.inc"
|
||||
#include "iterators/l1_to_bt_iterator.inc"
|
||||
#include "iterators/l1_to_fb_iterator.inc"
|
||||
#include "iterators/l1_to_l0_iterator.inc"
|
||||
#include "iterators/l1_to_ub_iterator.inc"
|
||||
#endif
|
||||
@@ -0,0 +1,162 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
// Partial specialization for V220, ND_in, ND_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ND> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
AscendC::DataCopy(l1Tensor, // dst
|
||||
gmTensor, // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
CeilDiv<BLOCK_SIZE>(nTileActual * dTileActual), // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for NZ_in, NZ_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
uint64_t srcStride = nVal - nTileCeil;
|
||||
if (srcStride < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor, // dst
|
||||
gmTensor, // src
|
||||
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
srcStride, // srcGap
|
||||
0)); // dstGap
|
||||
} else {
|
||||
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; i++) {
|
||||
uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE;
|
||||
uint64_t srcOffset = i * nVal * BLOCK_SIZE;
|
||||
AscendC::DataCopy(l1Tensor[dstOffset], // dst
|
||||
gmTensor[srcOffset], // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, ND_in, ND_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
if (dVal < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
nTileActual, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
dVal, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
1, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
} else {
|
||||
for (uint32_t i = 0; i < nTileActual; i++) {
|
||||
AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE],
|
||||
gmTensor[i * dVal],
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
1, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
0, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
0, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, ND_in, NZ_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ZN> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
if (dVal < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
nTileActual, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
dVal, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
1, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
} else {
|
||||
for (uint32_t i = 0; i < nTileActual; ++i) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
1, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
0, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
0, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,89 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct gm_to_ub {
|
||||
__aicore__ inline gm_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct gm_to_ub_align {
|
||||
__aicore__ inline gm_to_ub_align(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
|
||||
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
|
||||
{
|
||||
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0),
|
||||
AscendC::DataCopyPadExtParams<DType>(false, leftPaddingNum, rightPaddingNum, 0));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct ub_to_ub {
|
||||
__aicore__ inline ub_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType, DataFormat InDataFormat = DataFormat::ND,
|
||||
DataFormat OutDataFormat = DataFormat::ND>
|
||||
struct ub_to_gm {
|
||||
__aicore__ inline ub_to_gm(AscendC::GlobalTensor<DataType> dstTensor, AscendC::LocalTensor<DataType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType> struct ub_to_gm<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ ub_to_gm(AscendC::GlobalTensor<DataType> gmTensor, AscendC::LocalTensor<DataType> ubTensor,
|
||||
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
|
||||
uint32_t dTileCeil, uint32_t dVal)
|
||||
{
|
||||
constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
uint64_t dstStride = nVal - nTileCeil;
|
||||
if (dstStride < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(gmTensor, // dst
|
||||
ubTensor, // src
|
||||
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
dstStride)); // dstGap
|
||||
} else {
|
||||
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; ++i) {
|
||||
uint64_t dstOffset = i * nVal * BLOCK_SIZE;
|
||||
uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE;
|
||||
AscendC::DataCopy(gmTensor[dstOffset], // dst
|
||||
ubTensor[srcOffset], // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct ub_to_gm_align {
|
||||
__aicore__ inline ub_to_gm_align(AscendC::GlobalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
|
||||
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
|
||||
{
|
||||
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0));
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,228 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
constexpr uint32_t BLOCK_NUM = 16;
|
||||
constexpr uint32_t BLOCK_SIZE_INT8 = 32;
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, float> {
|
||||
/**
|
||||
* @brief Copy data from L0C buffer to global memory, partial specialized for
|
||||
*
|
||||
* @param gmTensor the destination tensor on global memory, which is stored in ND format.
|
||||
* @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format.
|
||||
* @param mTileActual the m-direction size of the matrix in L0C buffer.
|
||||
* @param nTileActual the n-direction size of the matrix in L0C buffer.
|
||||
* @param srcStride the source stride between the adjacent fractal matrix along n-direction in unit of C0_SIZE.
|
||||
* @param dstStride the leading dimension of the destination matrix in unit of element.
|
||||
*/
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322F16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::F322F16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, int32_t> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::VDEQF16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<int32_t> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, __bf16, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322BF16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::F322BF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
// Partial specialization ND, float
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, float, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<float> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::NoQuant;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<float, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::NoQuant};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::NZ, half, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322F16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, float, AscendC::CFG_NZ>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride - (nTileActual * sizeof(half) / sizeof(float)));
|
||||
intriParams.quantParams = {QuantMode_t::F322F16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, int32_t, int32_t> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<int32_t> gmTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::NoQuant;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<int32_t, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<int32_t> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,42 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_l1
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization ZN, half, int32_t
|
||||
template <ArchType ArchTag>
|
||||
struct l0c_to_l1<ArchTag, DataFormat::ZN, half, int32_t> {
|
||||
using ElementOut = half;
|
||||
using ElementIn = int32_t;
|
||||
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor,
|
||||
AscendC::LocalTensor<ElementIn> l0cTensor,
|
||||
AscendC::LocalTensor<uint64_t> deqTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t nActual)
|
||||
{
|
||||
constexpr uint32_t BLOCK_NUM = 16;
|
||||
constexpr uint32_t BLOCK_SIZE = 32;
|
||||
AscendC::FixpipeParams<ElementIn> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE),
|
||||
0,
|
||||
mTileCeil - static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) *
|
||||
sizeof(ElementOut) / sizeof(ElementIn));
|
||||
intriParams.nz2ndParams = {false, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams);
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,71 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_ub
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization ZN, half, int32_t
|
||||
template <ArchType ArchTag, typename ElementIn, typename ElementOut, bool MatrixMode = true>
|
||||
struct l0c_to_ub {
|
||||
__aicore__ l0c_to_ub(AscendC::LocalTensor<ElementOut> ubTensor,
|
||||
AscendC::LocalTensor<ElementIn> l0cTensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
constexpr auto mode =
|
||||
MatrixMode ? AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR;
|
||||
AscendC::DataCopy(ubTensor,
|
||||
l0cTensor,
|
||||
AscendC::DataCopyParams(nBurst, // count
|
||||
lenBurst, // len
|
||||
srcStride, // srcStrideIn
|
||||
dstStride), // dstStrideIn
|
||||
AscendC::DataCopyEnhancedParams(mode, // blockModeIn
|
||||
AscendC::DeqScale::DEQ_NONE, // deqScaleIn
|
||||
0, // deqValueIn
|
||||
0, // sidStoreModeIn
|
||||
false, // isReluIn
|
||||
pad_t::PAD_NONE, // padModeIn
|
||||
0) // padValueIn
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct l0c_to_ub<ArchTag, int32_t, half> {
|
||||
__aicore__ l0c_to_ub(AscendC::LocalTensor<half> ubTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(ubTensor,
|
||||
l0cTensor,
|
||||
AscendC::DataCopyParams(nBurst, // count
|
||||
lenBurst, // len
|
||||
srcStride, // srcStrideIn
|
||||
dstStride), // dstStrideIn
|
||||
AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn
|
||||
AscendC::DeqScale::VDEQ16, // deqScaleIn
|
||||
0, // deqValueIn
|
||||
0, // sidStoreModeIn
|
||||
false, // isReluIn
|
||||
pad_t::PAD_NONE, // padModeIn
|
||||
0) // padValueIn
|
||||
);
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,39 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_bt
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for V220
|
||||
template <typename DataType>
|
||||
struct l1_to_bt<ArchType::ASCEND_V220, DataType> {
|
||||
__aicore__ l1_to_bt(uint64_t dst,
|
||||
const AscendC::LocalTensor<DataType> &src,
|
||||
uint16_t convControl,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcGap,
|
||||
uint16_t dstGap)
|
||||
{
|
||||
AscendC::LocalTensor<DataType> dstTensor;
|
||||
dstTensor.InitBuffer(dst, nBurst * lenBurst);
|
||||
dstTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
|
||||
AscendC::DataCopy(dstTensor,
|
||||
src,
|
||||
AscendC::DataCopyParams(nBurst, // nBurst
|
||||
lenBurst, // lenBurst
|
||||
srcGap, // srcGap
|
||||
dstGap)); // dstGap
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,36 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_fb
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for V220
|
||||
template <typename DataType>
|
||||
struct l1_to_fb<ArchType::ASCEND_V220, DataType> {
|
||||
__aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst,
|
||||
AscendC::LocalTensor<DataType> &src,
|
||||
uint16_t burstNum,
|
||||
uint16_t burstLen,
|
||||
uint16_t srcGap,
|
||||
uint16_t dstGap)
|
||||
{
|
||||
dst.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2PIPE2GM);
|
||||
AscendC::DataCopy(dst,
|
||||
src,
|
||||
AscendC::DataCopyParams(burstNum, // nBurst
|
||||
burstLen, // lenBurst
|
||||
srcGap, // srcGap
|
||||
dstGap)); // dstGap);
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,310 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_a
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for vector
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose>
|
||||
struct l1_to_l0_a<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(l0Tensor,
|
||||
l1Tensor,
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
kPartCeil, // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // sid
|
||||
kDstStride, // dstStride
|
||||
IsTransPose, // transpose
|
||||
0)); // addrCalMode
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for no transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], // dst
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // src
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(kPartCeil / BLOCK_SIZE), // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // sid
|
||||
kDstStride - 1, // dstStride
|
||||
false, // transpose
|
||||
0)); // addrCalMode
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, true, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE],
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0,
|
||||
static_cast<uint16_t>(kPartCeil / BLOCK_NUM_PER_FRACTAL),
|
||||
kSrcStride,
|
||||
0,
|
||||
kDstStride - 1,
|
||||
true,
|
||||
0));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::NZ, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
// 16 * 32
|
||||
static constexpr uint32_t ROW_BLOCK_SIZE = 16;
|
||||
static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * ROW_BLOCK_SIZE * kPartCeil],
|
||||
l1Tensor[i * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0,
|
||||
static_cast<uint16_t>(kPartCeil / COL_BLOCK_SIZE),
|
||||
mTileCeil / ROW_BLOCK_SIZE,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
0));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l1_to_l0_a<ArchType::ASCEND_V220, int8_t, true, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 512
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; // 16
|
||||
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<int8_t> l0Tensor,
|
||||
AscendC::LocalTensor<int8_t> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint64_t i = 0; i < mTileCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
|
||||
AscendC::LoadDataWithTranspose(
|
||||
l0Tensor[i * mDstStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // dstLocalTensor
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // srcLocalTensor
|
||||
AscendC::LoadData2dTransposeParams(0, // baseIdx
|
||||
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(kPartCeil)), // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // dstGap
|
||||
mDstStride - 1)); // dstFracGap
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_b
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for vector
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose>
|
||||
struct l1_to_l0_b<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(
|
||||
l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct l1_to_l0_b<ArchTag, int8_t, true, DataFormat::NZ, DataFormat::ZN> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
using DataType = int8_t;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) {
|
||||
AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE],
|
||||
l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE],
|
||||
AscendC::LoadData2dTransposeParams(0, // startIndexIn
|
||||
kPartCeil / BLOCK_SIZE, // repeatTimesIn
|
||||
nTileCeil / BLOCK_SIZE, // srcStrideIn
|
||||
1, // dstGapIn
|
||||
0, // dstfracGapIn
|
||||
0) // addrModeIn
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for no transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_b<ArchTag, DataType, false, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < kPartCeil / BLOCK_NUM_PER_FRACTAL; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE],
|
||||
l1Tensor[i * kSrcStride * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(nTileCeil / BLOCK_SIZE), // repeat
|
||||
nSrcStride, // srcStride
|
||||
0, // sid
|
||||
nDstStride - 1, // dstStride
|
||||
true, // transpose
|
||||
0)); // addrCalMode
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_b<ArchTag, DataType, true, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(
|
||||
l0Tensor,
|
||||
l1Tensor,
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat
|
||||
1, // srcStride
|
||||
0, // sid
|
||||
0, // dstStride
|
||||
false, // transpose
|
||||
0)); // addr_cal_mode_t
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l1_to_l0_b<ArchType::ASCEND_V220, int8_t, false, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 16
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<int8_t> l0Tensor,
|
||||
AscendC::LocalTensor<int8_t> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint64_t i = 0; i < kPartCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
|
||||
AscendC::LoadDataWithTranspose(
|
||||
l0Tensor[i * kDstStride * FRACTAL_SIZE], // dstLocalTensor
|
||||
l1Tensor[i * kSrcStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // srcLocalTensor
|
||||
AscendC::LoadData2dTransposeParams(0, // baseIdx
|
||||
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(nTileCeil)), // repeat
|
||||
nSrcStride / NUM_FRACTAL_PER_ITER, // srcStride
|
||||
1, // dstGap
|
||||
0)); // dstFracGap
|
||||
}
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_ub
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_ub {
|
||||
__aicore__ l1_to_ub(AscendC::LocalTensor<DataType> ubTensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// ub_to_l1
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct ub_to_l1 {
|
||||
__aicore__ ub_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::LocalTensor<DataType> ubTensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
395
csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h
Normal file
395
csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h
Normal file
@@ -0,0 +1,395 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
|
||||
#define ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using AscendC::HardEvent;
|
||||
|
||||
__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y)
|
||||
{
|
||||
return y == 0 ? 0 : ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t RoundUp(uint32_t x, uint32_t y = 16)
|
||||
{
|
||||
return (x + y - 1) / y * y;
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t Min(uint32_t x, uint32_t y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t Max(uint32_t x, uint32_t y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyIn(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
|
||||
DataCopy(local, gm[offset], count);
|
||||
queue.EnQue(local);
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyOut(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
|
||||
DataCopy(gm[offset], local, count);
|
||||
queue.FreeTensor(local);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CastFrom16To32(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &in,
|
||||
uint32_t count)
|
||||
{
|
||||
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CastFrom32To16(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
uint32_t count)
|
||||
{
|
||||
if constexpr (AscendC::IsSameType<T, half>::value) {
|
||||
Cast(out, in, AscendC::RoundMode::CAST_NONE,
|
||||
count); // 310p cast fp32->half 只能用CAST_NONE,这里拉齐310p和910b
|
||||
} else { // bf16
|
||||
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void CastFromF16ToI8(const AscendC::LocalTensor<int8_t> &out, const AscendC::LocalTensor<half> &in,
|
||||
half quantMin, uint32_t count)
|
||||
{
|
||||
Maxs(in, in, quantMin, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mins(in, in, (half)127, count); // 127: limit
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
|
||||
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
|
||||
#else
|
||||
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
|
||||
#endif
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyInAndCastF32(const AscendC::LocalTensor<float> &out, const AscendC::GlobalTensor<T> &gm,
|
||||
Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
CopyIn(gm, queue, offset, count);
|
||||
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
|
||||
Cast(out, local, AscendC::RoundMode::CAST_NONE, count);
|
||||
queue.FreeTensor(local);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void Cast16AndCopyOut(const AscendC::LocalTensor<float> &in, const AscendC::GlobalTensor<T> &gm,
|
||||
Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
|
||||
CastFrom32To16(local, in, count);
|
||||
queue.EnQue(local);
|
||||
CopyOut(gm, queue, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T ComputeSum(const AscendC::LocalTensor<T> &in, const AscendC::LocalTensor<T> &tmp,
|
||||
const AscendC::LocalTensor<T> &workLocal, uint32_t count)
|
||||
{
|
||||
#if __CCE_AICORE__ == 100
|
||||
float sum = 0;
|
||||
int64_t elementNumPerRep = AscendC::ONE_REPEAT_BYTE_SIZE / sizeof(T);
|
||||
AscendC::LocalTensor<T> src = in;
|
||||
while (count > elementNumPerRep) {
|
||||
int64_t repeatTimes = count / elementNumPerRep;
|
||||
int64_t tailCount = count % elementNumPerRep;
|
||||
int64_t bodyCount = repeatTimes * elementNumPerRep;
|
||||
if (repeatTimes > 0) {
|
||||
AscendC::AscendCUtils::SetMask<T>(elementNumPerRep);
|
||||
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)src.GetPhyAddr(), repeatTimes, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0); // PipeBarrier(PIPE_V)?
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
if (tailCount != 0) {
|
||||
AscendC::AscendCUtils::SetMask<T>(tailCount);
|
||||
vcadd((__ubuf__ T *)tmp[bodyCount].GetPhyAddr(), (__ubuf__ T *)src[bodyCount].GetPhyAddr(), 1, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
sum += tmp.GetValue(bodyCount);
|
||||
}
|
||||
|
||||
count = repeatTimes;
|
||||
src = tmp;
|
||||
}
|
||||
|
||||
if (count > 1) {
|
||||
AscendC::AscendCUtils::SetMask<T>(count);
|
||||
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)tmp.GetPhyAddr(), 1, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
sum += tmp.GetValue(0);
|
||||
return sum;
|
||||
#else
|
||||
ReduceSum(tmp, in, workLocal, count);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
return tmp.GetValue(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
__aicore__ inline float ComputeSliceSquareSum(const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<float> &tmp,
|
||||
const AscendC::LocalTensor<float> &workLocal, uint32_t count)
|
||||
{
|
||||
Mul(tmp, in, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
return ComputeSum(tmp, tmp, workLocal, count);
|
||||
}
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
|
||||
uint32_t precisionMode, uint32_t gemmaMode,
|
||||
const AscendC::LocalTensor<float> &tmp)
|
||||
{
|
||||
float value = 1.0;
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(tmp, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
if (precisionMode == 0) {
|
||||
CastFrom16To32(in, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (gemmaMode == 1) {
|
||||
Adds(in, in, value, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Mul(in, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(out, in, count);
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
CastFrom32To16(out, tmp, count);
|
||||
Mul(out, out, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, uint32_t gemmaMode>
|
||||
__aicore__ inline void CastGAndIsGemmaMode(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &gamma,
|
||||
uint32_t count)
|
||||
{
|
||||
Cast(out, gamma, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
float value = 1.0;
|
||||
if constexpr (gemmaMode == 1) {
|
||||
Adds(out, out, value, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, uint32_t precisionMode>
|
||||
__aicore__ inline void ComputeRmsNormFast(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
|
||||
const AscendC::LocalTensor<float> &tmp,
|
||||
const AscendC::LocalTensor<float> &fp32_g)
|
||||
{
|
||||
float value = 1.0;
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(tmp, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (precisionMode == 0) {
|
||||
Mul(in, fp32_g, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(out, in, count);
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
CastFrom32To16(out, tmp, count);
|
||||
Mul(out, out, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <bool WITH_BETA = true>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<half> &gamma,
|
||||
const AscendC::LocalTensor<half> &beta, const AscendC::LocalTensor<float> &tmp,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(out, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom16To32(tmp, gamma, count);
|
||||
Mul(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (WITH_BETA) {
|
||||
CastFrom16To32(tmp, beta, count);
|
||||
Add(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
float reciprocal_of_rms, const AscendC::LocalTensor<T> &gamma,
|
||||
const AscendC::LocalTensor<float> &tmp, const AscendC::LocalTensor<T> &res_out,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(tmp, reciprocal_of_rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom16To32(tmp, gamma, count);
|
||||
Mul(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(res_out, out, count);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeResidualAdd(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in,
|
||||
const AscendC::LocalTensor<T> &resIn, uint32_t count)
|
||||
{
|
||||
Add(out, in, resIn, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeMean(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in, T aveNum,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(out, aveNum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
T sum = ComputeSum(out, out, out, count);
|
||||
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
Duplicate(out, sum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeLayerNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<float> &mean, float eps, float aveNum,
|
||||
const AscendC::LocalTensor<T> &gamma, const AscendC::LocalTensor<T> &beta,
|
||||
uint32_t count)
|
||||
{
|
||||
Sub(in, in, mean, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Muls(out, out, aveNum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSum(out, out, out, count);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
float var = out.GetValue(0);
|
||||
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
Duplicate(out, var, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(out, out, eps, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Sqrt(out, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
Div(out, in, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
Cast(in, gamma, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, out, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(in, beta, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(out, out, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeFp16ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<half> &in, const AscendC::LocalTensor<half> &tmp,
|
||||
half scale, half offset, half quantMin, uint32_t count)
|
||||
{
|
||||
Muls(tmp, in, scale, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(tmp, tmp, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFromF16ToI8(out, tmp, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<half> &tmp, half scale, half offset,
|
||||
half quantMin, uint32_t count)
|
||||
{
|
||||
CastFrom32To16(tmp, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ComputeFp16ToI8Quant(out, tmp, tmp, scale, offset, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeHighPrecisionFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<half> &tmp, float scale,
|
||||
float offset, half quantMin, uint32_t count)
|
||||
{
|
||||
Muls(in, in, scale, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(in, in, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(tmp, in, count);
|
||||
CastFromF16ToI8(out, tmp, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyGmTilingToUb(__ubuf__ uint8_t *&tilingInUb, const __gm__ uint8_t *tilingInGm,
|
||||
size_t tilingSize, AscendC::TPipe *pipe)
|
||||
{
|
||||
uint32_t roundTilingSize = RoundUp(tilingSize, 32);
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> tilingBuf;
|
||||
AscendC::GlobalTensor<uint8_t> tilingGm;
|
||||
|
||||
tilingGm.SetGlobalBuffer((__gm__ uint8_t *)tilingInGm);
|
||||
pipe->InitBuffer(tilingBuf, roundTilingSize);
|
||||
|
||||
AscendC::LocalTensor<uint8_t> tilingUb = tilingBuf.Get<uint8_t>();
|
||||
AscendC::DataCopy(tilingUb, tilingGm, roundTilingSize);
|
||||
|
||||
tilingInUb = (__ubuf__ uint8_t *)tilingUb.GetPhyAddr();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline uint32_t GetReduceSumWorkLocalSize(uint32_t sliceSize)
|
||||
{
|
||||
uint32_t elementsPerBlock = 32 / sizeof(T);
|
||||
uint32_t elementsPerRepeat = 256 / sizeof(T);
|
||||
|
||||
uint32_t firstMaxRepeat = sliceSize < elementsPerRepeat ? 1u : (sliceSize / elementsPerRepeat);
|
||||
uint32_t iter1OutputCount = firstMaxRepeat;
|
||||
uint32_t iter1AlignEnd = RoundUp(iter1OutputCount, elementsPerBlock);
|
||||
return iter1AlignEnd;
|
||||
}
|
||||
|
||||
#endif
|
||||
18
csrc/mla_preprocess/op_kernel/kernel/layout.h
Normal file
18
csrc/mla_preprocess/op_kernel/kernel/layout.h
Normal file
@@ -0,0 +1,18 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef INCLUDE_LAYOUT_H
|
||||
#define INCLUDE_LAYOUT_H
|
||||
|
||||
enum class DataFormat { ND = 0, NZ, ZN, ZZ, NN, VECTOR };
|
||||
|
||||
#endif
|
||||
82
csrc/mla_preprocess/op_kernel/kernel/mem.h
Normal file
82
csrc/mla_preprocess/op_kernel/kernel/mem.h
Normal file
@@ -0,0 +1,82 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_MEM_H
|
||||
#define INCLUDE_MEM_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_event.h"
|
||||
#include "kernel_tensor.h"
|
||||
|
||||
enum class BufferType { ASCEND_UB, ASCEND_CB, ASCEND_L0A, ASCEND_L0B, ASCEND_L0C, ASCEND_MAX };
|
||||
|
||||
template <BufferType BufferType_>
|
||||
__aicore__ constexpr AscendC::TPosition GetPosition()
|
||||
{
|
||||
if constexpr (BufferType_ == BufferType::ASCEND_UB) {
|
||||
return AscendC::TPosition::VECIN;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_CB) {
|
||||
return AscendC::TPosition::A1;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0A) {
|
||||
return AscendC::TPosition::A2;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0B) {
|
||||
return AscendC::TPosition::B2;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0C) {
|
||||
return AscendC::TPosition::CO1;
|
||||
}
|
||||
return AscendC::TPosition::GM;
|
||||
}
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct AsdopsBuffer {
|
||||
public:
|
||||
__aicore__ AsdopsBuffer()
|
||||
{
|
||||
constexpr uint32_t bufferSize[(uint32_t)BufferType::ASCEND_MAX] = {
|
||||
HardwareInfo<ArchTag>::ubSize, HardwareInfo<ArchTag>::l1Size, HardwareInfo<ArchTag>::l0ASize,
|
||||
HardwareInfo<ArchTag>::l0BSize, HardwareInfo<ArchTag>::l0CSize};
|
||||
#ifdef __DAV_C220_VEC__
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
#elif defined(__DAV_C220_CUBE__)
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
|
||||
#else
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <BufferType BufferType_, typename DstDataType = half>
|
||||
__aicore__ AscendC::LocalTensor<DstDataType> GetBuffer(const uint32_t offset) const
|
||||
{
|
||||
return tensor[(uint32_t)BufferType_][offset].template ReinterpretCast<DstDataType>();
|
||||
}
|
||||
|
||||
public:
|
||||
AscendC::LocalTensor<uint8_t> tensor[(uint32_t)BufferType::ASCEND_MAX];
|
||||
};
|
||||
|
||||
#endif
|
||||
67
csrc/mla_preprocess/op_kernel/kernel/mma.h
Normal file
67
csrc/mla_preprocess/op_kernel/kernel/mma.h
Normal file
@@ -0,0 +1,67 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_MMA_H
|
||||
#define INCLUDE_MMA_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_tensor.h"
|
||||
|
||||
template <ArchType ArchTag, typename ElementA, typename ElementB, typename AccDTypeC, bool IsTransposeA>
|
||||
struct mmad {
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
|
||||
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
|
||||
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
|
||||
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, int8_t, not_vector_A, not TransposeA
|
||||
template <ArchType ArchTag, typename AccDTypeC, typename ElementA, typename ElementB>
|
||||
struct mmad<ArchTag, ElementA, ElementB, AccDTypeC, false> {
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
|
||||
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::Mmad(l0cTensor, // C
|
||||
l0aTensor, // A
|
||||
l0bTensor, // B
|
||||
AscendC::MmadParams(mTileActual, // m
|
||||
nTileActual, // n
|
||||
kPartActual, // k
|
||||
unitFlag, // unitFlag
|
||||
false, // cmatrixSource
|
||||
initC)); // cmatrixInitVal
|
||||
};
|
||||
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
|
||||
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::LocalTensor<AccDTypeC> biasTensor;
|
||||
biasTensor.InitBuffer(biasBt, nTileActual);
|
||||
biasTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
|
||||
AscendC::Mmad(l0cTensor, // C
|
||||
l0aTensor, // A
|
||||
l0bTensor, // B
|
||||
biasTensor, // bt
|
||||
AscendC::MmadParams(mTileActual, // m
|
||||
nTileActual, // n
|
||||
kPartActual, // k
|
||||
unitFlag, // unitFlag
|
||||
true, // cmatrixSource
|
||||
false)); // cmatrixInitVal
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
38
csrc/mla_preprocess/op_kernel/kernel/set_fpc.h
Normal file
38
csrc/mla_preprocess/op_kernel/kernel/set_fpc.h
Normal file
@@ -0,0 +1,38 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_SET_FPC_H
|
||||
#define INCLUDE_SET_FPC_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_tensor.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// SetQuantPreAddr
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct SetQuantPreAddr {
|
||||
__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor) {};
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct SetQuantPreAddr<ArchType::ASCEND_V220, DataType> {
|
||||
static constexpr uint32_t QUANT_PRE_ADDR_MASK = 0xffff;
|
||||
static constexpr uint32_t USELESS_BIT_NUM = 7;
|
||||
static constexpr uint32_t QUANT_PRE_BIT_POS_IN_FPC = 8;
|
||||
|
||||
__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor)
|
||||
{
|
||||
uint64_t quantPreAddr = (uint64_t)(__fbuf__ uint64_t *)quantPreTensor.GetPhyAddr();
|
||||
AscendC::SetFixPipeConfigImpl(quantPreTensor);
|
||||
};
|
||||
};
|
||||
#endif
|
||||
274
csrc/mla_preprocess/op_kernel/kernel/simd.h
Normal file
274
csrc/mla_preprocess/op_kernel/kernel/simd.h
Normal file
@@ -0,0 +1,274 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_SIMD_H
|
||||
#define INCLUDE_SIMD_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_operator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vcgadd
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void cgadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
|
||||
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
|
||||
{
|
||||
AscendC::BlockReduceSum<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vadd
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void add_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
|
||||
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
|
||||
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
|
||||
{
|
||||
AscendC::Add<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
|
||||
dstRepeatStride, src0RepeatStride, src1RepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vadds
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void adds_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, DType scalarValue,
|
||||
uint8_t repeat, uint8_t dstBlockStride, uint8_t srcBlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Adds<DType, false>(
|
||||
dst, src, scalarValue, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vcadd
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void cadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
|
||||
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::RepeatReduceSum<DType, false>(dst, src, repeat, 0, 0, srcBlockStride, dstRepeatStride, srcRepeatStride);
|
||||
}
|
||||
/////////////////////////////////////////////////////
|
||||
// vbrcb
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void brcb_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint16_t dstBlockStride,
|
||||
uint16_t dstRepeatStride, uint8_t repeat)
|
||||
{
|
||||
AscendC::Brcb(dst, src, repeat, AscendC::BrcbRepeatParams(dstBlockStride, dstRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vcmax
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType, AscendC::ReduceOrder OrderType>
|
||||
__aicore__ inline void cmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
|
||||
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
|
||||
{
|
||||
#if defined(__DAV_C220_VEC__)
|
||||
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
|
||||
srcRepeatStride, OrderType);
|
||||
#else
|
||||
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
|
||||
srcRepeatStride);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vconv
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
|
||||
__aicore__ inline void conv_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
|
||||
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
|
||||
uint16_t srcRepeatStride)
|
||||
{
|
||||
if constexpr (std::is_same<DTypeIn, float>::value && std::is_same<DTypeOut, __bf16>::value) {
|
||||
AscendC::Cast<DTypeOut, DTypeIn, false>(
|
||||
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
} else {
|
||||
AscendC::Cast<DTypeOut, DTypeIn, false>(
|
||||
dst, src, AscendC::RoundMode::CAST_NONE, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vconv_f322bf16r
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
|
||||
__aicore__ inline void convr_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
|
||||
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
|
||||
uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Cast<DTypeOut, DTypeIn, false>(
|
||||
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vdiv
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void div_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
|
||||
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
|
||||
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
|
||||
{
|
||||
AscendC::Div<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
|
||||
dstRepeatStride, src0RepeatStride, src1RepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vexp
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void exp_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
|
||||
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
|
||||
uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Exp<DType, false>(
|
||||
dst, src, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vmax
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void max_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
|
||||
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
|
||||
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
|
||||
{
|
||||
AscendC::Max<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
|
||||
dstRepeatStride, src0RepeatStride, src1RepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vmul
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void mul_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
|
||||
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
|
||||
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
|
||||
{
|
||||
AscendC::Mul<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
|
||||
dstRepeatStride, src0RepeatStride, src1RepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vmuls
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void muls_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
|
||||
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
|
||||
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Muls<DType, false>(
|
||||
dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vsub
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void sub_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
|
||||
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
|
||||
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
|
||||
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
|
||||
{
|
||||
AscendC::Sub<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
|
||||
dstRepeatStride, src0RepeatStride, src1RepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vmaxs
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void maxs_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
|
||||
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
|
||||
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Maxs<DType, false>(
|
||||
dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vmins
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void mins_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
|
||||
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
|
||||
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Mins<DType, false>(
|
||||
dst, src0, src1, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vsqrt
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void sqrt_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
|
||||
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
|
||||
uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Sqrt<DType, false>(
|
||||
dst, src, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vln
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void ln_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
|
||||
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
|
||||
uint16_t srcRepeatStride)
|
||||
{
|
||||
AscendC::Ln<DType, false>(
|
||||
dst, src, (uint64_t)0, repeat,
|
||||
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vtranspose
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void tranpose_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src)
|
||||
{
|
||||
AscendC::Transpose(dst, src);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// vcgmax
|
||||
/////////////////////////////////////////////////////
|
||||
template <ArchType ArchTag, typename DType>
|
||||
__aicore__ inline void cgmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
|
||||
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
|
||||
{
|
||||
AscendC::BlockReduceMax<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
|
||||
}
|
||||
#endif
|
||||
69
csrc/mla_preprocess/op_kernel/kernel/utils.h
Normal file
69
csrc/mla_preprocess/op_kernel/kernel/utils.h
Normal file
@@ -0,0 +1,69 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_UTILS_H
|
||||
#define INCLUDE_UTILS_H
|
||||
|
||||
template <typename IN_DTYPE>
|
||||
__aicore__ inline void CreateCaMatrix(const AscendC::LocalTensor<IN_DTYPE> &dst, const uint16_t repeats,
|
||||
const uint16_t blockNum, const uint16_t dstGap, const IN_DTYPE initValue)
|
||||
{
|
||||
AscendC::InitConstValue<IN_DTYPE>(dst,
|
||||
AscendC::InitConstValueParams<IN_DTYPE>(repeats, blockNum, dstGap, initValue));
|
||||
}
|
||||
__aicore__ inline void SetFftsBaseAddr(uint64_t config)
|
||||
{
|
||||
AscendC::SetSyncBaseAddr(config);
|
||||
}
|
||||
template <typename IN_DTYPE>
|
||||
__aicore__ inline void SetPadding(IN_DTYPE padValue)
|
||||
{
|
||||
AscendC::SetLoadDataPaddingValue<IN_DTYPE>(padValue);
|
||||
}
|
||||
__aicore__ inline void SetAtomicnone()
|
||||
{
|
||||
AscendC::SetAtomicNone();
|
||||
}
|
||||
__aicore__ inline void SetMasknorm()
|
||||
{
|
||||
#if __CCE_AICORE__ == 100
|
||||
return;
|
||||
#endif
|
||||
AscendC::SetMaskNorm();
|
||||
}
|
||||
__aicore__ inline void SetNdpara(uint16_t ndNum, uint16_t srcNdStride, uint16_t dstNdStride)
|
||||
{
|
||||
AscendC::SetFixpipeNz2ndFlag(ndNum, srcNdStride, dstNdStride);
|
||||
}
|
||||
template <typename IN_DTYPE>
|
||||
__aicore__ inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)
|
||||
{
|
||||
AscendC::SetVectorMask<IN_DTYPE>(maskHigh, maskLow);
|
||||
}
|
||||
__aicore__ inline int64_t GetSubBlockidx()
|
||||
{
|
||||
return AscendC::GetSubBlockIdx();
|
||||
}
|
||||
__aicore__ inline void WaitFlagDev(uint16_t flagId)
|
||||
{
|
||||
AscendC::WaitEvent(flagId);
|
||||
}
|
||||
template <pipe_t pipe, uint8_t mode>
|
||||
__aicore__ inline void FftsCrossCoreSync(uint16_t flagId)
|
||||
{
|
||||
AscendC::CrossCoreSetFlag<mode, pipe>(flagId);
|
||||
}
|
||||
template <typename IN_DTYPE, bool setRelu = false>
|
||||
__aicore__ inline void SetFpc(const AscendC::LocalTensor<IN_DTYPE> &preTensor, bool isUnitFlag = false)
|
||||
{
|
||||
AscendC::SetFixPipeConfig<IN_DTYPE, setRelu>(preTensor, isUnitFlag);
|
||||
}
|
||||
#endif
|
||||
114
csrc/mla_preprocess/op_kernel/mla_preprocess.h
Normal file
114
csrc/mla_preprocess/op_kernel/mla_preprocess.h
Normal file
@@ -0,0 +1,114 @@
|
||||
// Adapted from
|
||||
// https://gitee.com/ascend/ascend-transformer-boost
|
||||
//
|
||||
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
// This file is a part of the CANN Open Software.
|
||||
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
// Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
// See LICENSE in the root of the software repository for the full text of the License.
|
||||
//
|
||||
|
||||
#ifndef __MLA_PREPROCESS_H__
|
||||
#define __MLA_PREPROCESS_H__
|
||||
|
||||
// sync
|
||||
constexpr int32_t QUANT1 = 1;
|
||||
constexpr int32_t MM1 = 2;
|
||||
constexpr int32_t MM1QUANT = 3;
|
||||
constexpr int32_t RMSNORMQUANT2 = 4;
|
||||
constexpr int32_t MM2 = 5;
|
||||
constexpr int32_t MM2QUANT = 6;
|
||||
constexpr int32_t BMM3 = 7;
|
||||
constexpr int32_t BMM3SPLIT = 8;
|
||||
constexpr int32_t MM2OUT = 9;
|
||||
constexpr int32_t EINSUMOUT = 11;
|
||||
constexpr int32_t EINSUMQUANT = 12;
|
||||
|
||||
// ropeConcat
|
||||
constexpr uint32_t ELE_NUM_FP16 = 16; // nums of fp16 elements in one block
|
||||
constexpr uint32_t ELE_NUM_FP32 = 8; // nums of fp32 elements in one block
|
||||
constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8; // stride, 8 * 32 = 256
|
||||
|
||||
// rmsNormQuant
|
||||
constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
|
||||
constexpr float ZERO = 0;
|
||||
constexpr uint32_t BUF_FACTOR = 3; // 1(g) + 1(sqx) + 1(sum) = 3
|
||||
constexpr uint32_t OFFSET_GAMMA = 0; // the offset of gamma is 0
|
||||
constexpr uint32_t OFFSET_SQX = 1; // the offset of sqx is 1
|
||||
constexpr uint32_t OFFSET_SUM = 2; // the offset of sum is 2
|
||||
constexpr uint32_t OFFSET_WORKSPACE = 3; // the offset of workspace is 3
|
||||
constexpr uint32_t REPEAT_TIME_256 = 256; // 128 default stride
|
||||
constexpr uint32_t REPEAT_TIME_128 = 128; // 128 default stride
|
||||
constexpr uint32_t REPEAT_TIME_64 = 64; // 64 default stride
|
||||
|
||||
constexpr uint8_t CACHE_MODE_KVCACHE = 0; // single input single output
|
||||
constexpr uint8_t CACHE_MODE_KROPE_CTKV = 1; // double in and double out
|
||||
constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format/quant int8
|
||||
constexpr uint8_t CACHE_MODE_NZCACHE = 3;
|
||||
|
||||
// pp matmul
|
||||
constexpr uint32_t HIDDTEN_STATE = 7168;
|
||||
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
|
||||
constexpr uint32_t HALF_BLOCK_SIZE = 64;
|
||||
constexpr uint32_t HALF_VECTOR_SIZE = 64;
|
||||
constexpr uint32_t MM1_OUT_SIZE = 2112;
|
||||
constexpr uint32_t SPLIT_SIZE_ONE = 576;
|
||||
constexpr uint32_t SPLIT_SIZE_TWO = 1536;
|
||||
constexpr uint32_t SPLIT_RMSNRORM_SIZE_ONE = 512;
|
||||
constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64;
|
||||
constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64;
|
||||
constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128;
|
||||
|
||||
constexpr uint32_t MMSIZE1 = 128 * 192; // 24576
|
||||
constexpr uint32_t MMSIZE2 = 64 * 128; // 8192
|
||||
|
||||
constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; // 32 KB
|
||||
constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; // 256 KB
|
||||
constexpr uint64_t BLOCK_SIZE_16 = 16;
|
||||
constexpr uint64_t BLOCK_SIZE_32 = 32;
|
||||
constexpr uint64_t CUBE_MATRIX_SIZE_512 = 16 * 32; // 16 * 23
|
||||
constexpr uint64_t FB_BUFF_SIZE = 1024 * 7;
|
||||
constexpr uint64_t SCALE_L1_LEN = 4096;
|
||||
constexpr uint64_t BIAS_L1_LEN = 2048;
|
||||
|
||||
constexpr uint64_t CONST_0 = 0;
|
||||
constexpr uint64_t CONST_4 = 4;
|
||||
constexpr uint64_t CONST_8 = 8;
|
||||
constexpr uint64_t CONST_32 = 32;
|
||||
constexpr uint64_t CONST_64 = 64;
|
||||
constexpr uint64_t CONST_128 = 128;
|
||||
|
||||
// ropeConcat
|
||||
constexpr uint32_t ROPE_CONCAT_NUM_BUFFER = 2;
|
||||
|
||||
// rmsNormQuant
|
||||
constexpr uint32_t OFFSET_ABS = 3; // the offset of abs is 3
|
||||
constexpr uint32_t OFFSET_WORKSPACE_BF16 = 4; // the offset of workspace is 4
|
||||
|
||||
// sync bf16
|
||||
constexpr int32_t AIC_MM1_START = 2;
|
||||
constexpr int32_t AIC_MM3_START = 3;
|
||||
constexpr int32_t AIC_MM2_START = 6;
|
||||
constexpr int32_t MMAIC = 7;
|
||||
constexpr int32_t MMAIV = 8;
|
||||
|
||||
constexpr uint32_t MAX_HW_SYNC_COUNTER = 5;
|
||||
constexpr uint32_t SYNC_MODE = 2;
|
||||
|
||||
// TilingKey
|
||||
constexpr uint32_t KEY_FP16_CACHEMODE_0_QUANTMODE_0 = 0;
|
||||
constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
|
||||
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
|
||||
|
||||
enum class QuantMode : int32_t {
|
||||
PER_TENSOR_ASYMM_QUANT = 0,
|
||||
PER_TOKEN_SYMM_QUANT,
|
||||
PER_TOKEN_ASYMM_QUANT,
|
||||
NO_QUANT,
|
||||
};
|
||||
|
||||
#endif // __MLA_PREPROCESS_H__
|
||||
299
csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
Normal file
299
csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
Normal file
@@ -0,0 +1,299 @@
|
||||
// Adapted from
|
||||
// https://gitee.com/ascend/ascend-transformer-boost
|
||||
//
|
||||
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
// This file is a part of the CANN Open Software.
|
||||
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
// Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
// See LICENSE in the root of the software repository for the full text of the License.
|
||||
//
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "../../kernels/types.h"
|
||||
|
||||
#include "mla_preprocess_mix_fp16.hpp"
|
||||
#include "mla_preprocess_mix_bf16.hpp"
|
||||
|
||||
#include "../op_host/tiling/mla_preprocess_tiling.h"
|
||||
|
||||
extern "C" __global__ __aicore__ void mla_preprocess(
|
||||
GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
|
||||
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
|
||||
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
|
||||
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
|
||||
GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
|
||||
PRELOAD(2);
|
||||
#endif
|
||||
|
||||
SetAtomicnone();
|
||||
SetMasknorm();
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
SetPadding<uint64_t>((uint64_t)0);
|
||||
SetNdpara(1, 0, 0);
|
||||
#endif
|
||||
|
||||
MlaTilingData mlaTilingData;
|
||||
__gm__ MlaTilingData *tilingData = reinterpret_cast<__gm__ MlaTilingData *>(tiling);
|
||||
|
||||
mlaTilingData.tilingKey = tilingData->tilingKey;
|
||||
mlaTilingData.n = tilingData->n;
|
||||
|
||||
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
|
||||
mlaTilingData.mm1.m = tilingData->mm1.m;
|
||||
mlaTilingData.mm1.k = tilingData->mm1.k;
|
||||
mlaTilingData.mm1.n = tilingData->mm1.n;
|
||||
mlaTilingData.mm1.m0 = tilingData->mm1.m0;
|
||||
mlaTilingData.mm1.k0 = tilingData->mm1.k0;
|
||||
mlaTilingData.mm1.n0 = tilingData->mm1.n0;
|
||||
mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop;
|
||||
mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop;
|
||||
mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop;
|
||||
mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop;
|
||||
mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount;
|
||||
mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK;
|
||||
mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim;
|
||||
mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat;
|
||||
mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen;
|
||||
|
||||
mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch;
|
||||
mlaTilingData.mm2.m = tilingData->mm2.m;
|
||||
mlaTilingData.mm2.k = tilingData->mm2.k;
|
||||
mlaTilingData.mm2.n = tilingData->mm2.n;
|
||||
mlaTilingData.mm2.m0 = tilingData->mm2.m0;
|
||||
mlaTilingData.mm2.k0 = tilingData->mm2.k0;
|
||||
mlaTilingData.mm2.n0 = tilingData->mm2.n0;
|
||||
mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop;
|
||||
mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop;
|
||||
mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop;
|
||||
mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop;
|
||||
mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount;
|
||||
mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK;
|
||||
mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim;
|
||||
mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat;
|
||||
mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen;
|
||||
|
||||
mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch;
|
||||
mlaTilingData.mm3.m = tilingData->mm3.m;
|
||||
mlaTilingData.mm3.k = tilingData->mm3.k;
|
||||
mlaTilingData.mm3.n = tilingData->mm3.n;
|
||||
mlaTilingData.mm3.m0 = tilingData->mm3.m0;
|
||||
mlaTilingData.mm3.k0 = tilingData->mm3.k0;
|
||||
mlaTilingData.mm3.n0 = tilingData->mm3.n0;
|
||||
mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop;
|
||||
mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop;
|
||||
mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop;
|
||||
mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop;
|
||||
mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount;
|
||||
mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK;
|
||||
mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim;
|
||||
|
||||
mlaTilingData.perTaskNum = tilingData->perTaskNum;
|
||||
mlaTilingData.resTaskNum = tilingData->resTaskNum;
|
||||
mlaTilingData.numCore = tilingData->numCore;
|
||||
|
||||
mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1;
|
||||
mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1;
|
||||
mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2;
|
||||
mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2;
|
||||
|
||||
mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ;
|
||||
mlaTilingData.headNumQ = tilingData->headNumQ;
|
||||
mlaTilingData.headDim = tilingData->headDim;
|
||||
mlaTilingData.concatSize = tilingData->concatSize;
|
||||
mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff;
|
||||
mlaTilingData.ntokens = tilingData->ntokens;
|
||||
mlaTilingData.realCore = tilingData->realCore;
|
||||
mlaTilingData.nlCoreRun = tilingData->nlCoreRun;
|
||||
mlaTilingData.lCoreRun = tilingData->lCoreRun;
|
||||
mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb;
|
||||
mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime;
|
||||
mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast;
|
||||
mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime;
|
||||
mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast;
|
||||
|
||||
mlaTilingData.esqFrontCore = tilingData->esqFrontCore;
|
||||
mlaTilingData.esqTailCore = tilingData->esqTailCore;
|
||||
mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch;
|
||||
mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch;
|
||||
mlaTilingData.esqHeadNum = tilingData->esqHeadNum;
|
||||
mlaTilingData.esqColNum = tilingData->esqColNum;
|
||||
mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop;
|
||||
mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop;
|
||||
mlaTilingData.esqHeadTail = tilingData->esqHeadTail;
|
||||
mlaTilingData.esqColLoop = tilingData->esqColLoop;
|
||||
mlaTilingData.esqColTail = tilingData->esqColTail;
|
||||
|
||||
mlaTilingData.s1Offset = tilingData->s1Offset;
|
||||
mlaTilingData.s2Offset = tilingData->s2Offset;
|
||||
mlaTilingData.s3Offset = tilingData->s3Offset;
|
||||
mlaTilingData.s4Offset = tilingData->s4Offset;
|
||||
mlaTilingData.s5Offset = tilingData->s5Offset;
|
||||
|
||||
GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
|
||||
GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
|
||||
GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);
|
||||
GM_ADDR s4 = workspace + static_cast<uint64_t>(mlaTilingData.s4Offset);
|
||||
GM_ADDR s5 = workspace + static_cast<uint64_t>(mlaTilingData.s5Offset);
|
||||
|
||||
switch (mlaTilingData.tilingKey) {
|
||||
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
|
||||
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
|
||||
mlaTilingData, tiling);
|
||||
opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3);
|
||||
if ASCEND_IS_AIC {
|
||||
opFp16Cm0Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opFp16Cm0Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
|
||||
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
|
||||
opFp16Cm1Qm0(mlaTilingData, tiling);
|
||||
opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3);
|
||||
if ASCEND_IS_AIC {
|
||||
opFp16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opFp16Cm1Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm0Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm0Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opBf16Cm0Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm1Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm1Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opBf16Cm1Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case KEY_BF16_CACHEMODE_3_QUANTMODE_0: {
|
||||
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
|
||||
QuantMode::PER_TENSOR_ASYMM_QUANT>
|
||||
opBf16Cm3Qm0(mlaTilingData, tiling);
|
||||
opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
|
||||
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
|
||||
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
|
||||
s1, s2, s3, s4, s5);
|
||||
if ASCEND_IS_AIC {
|
||||
opBf16Cm3Qm0.ProcessCube();
|
||||
}
|
||||
if ASCEND_IS_AIV {
|
||||
opBf16Cm3Qm0.ProcessVector();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
extern void mla_preprocess_impl(
|
||||
void* stream,
|
||||
void* hidden_state,
|
||||
void* gamma1,
|
||||
void* beta1,
|
||||
void* quant_scale1,
|
||||
void* quant_offset1,
|
||||
void* wdqkv,
|
||||
void* bias1,
|
||||
void* gamma2,
|
||||
void* beta2,
|
||||
void* quant_scale2,
|
||||
void* quant_offset2,
|
||||
void* gamma3,
|
||||
void* sin1,
|
||||
void* cos1,
|
||||
void* sin2,
|
||||
void* cos2,
|
||||
void* keycache,
|
||||
void* slot_mapping,
|
||||
void* wuq,
|
||||
void* bias2,
|
||||
void* wuk,
|
||||
void* descale1,
|
||||
void* descale2,
|
||||
void* ctkv_scale,
|
||||
void* qnope_scale,
|
||||
void* q,
|
||||
void* keycache_out,
|
||||
void* q2,
|
||||
void* keycache_out2,
|
||||
void* workspace,
|
||||
void* tiling,
|
||||
const uint32_t block_dim)
|
||||
{
|
||||
mla_preprocess<<<block_dim, nullptr, stream>>>(
|
||||
hidden_state,
|
||||
gamma1,
|
||||
beta1,
|
||||
quant_scale1,
|
||||
quant_offset1,
|
||||
wdqkv,
|
||||
bias1,
|
||||
gamma2,
|
||||
beta2,
|
||||
quant_scale2,
|
||||
quant_offset2,
|
||||
gamma3,
|
||||
sin1,
|
||||
cos1,
|
||||
sin2,
|
||||
cos2,
|
||||
keycache,
|
||||
slot_mapping,
|
||||
wuq,
|
||||
bias2,
|
||||
wuk,
|
||||
descale1,
|
||||
descale2,
|
||||
ctkv_scale,
|
||||
qnope_scale,
|
||||
q,
|
||||
keycache_out,
|
||||
q2,
|
||||
keycache_out2,
|
||||
workspace,
|
||||
tiling);
|
||||
}
|
||||
|
||||
}
|
||||
2918
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
Normal file
2918
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2508
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
Normal file
2508
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user