add mla_preprocess kernel (#3226)

### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized
pre-processing operator for Multi-head Latent Attention (MLA) on Ascend
NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can
invoke it directly, cutting Python-side tensor shuffling and memory
copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and
inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-10-12 07:39:45 +08:00
committed by GitHub
parent 1b1207e3c3
commit bcc313e8f2
32 changed files with 9158 additions and 3 deletions

View File

@@ -0,0 +1,162 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
// Partial specialization for V220, ND_in, ND_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ND> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
AscendC::DataCopy(l1Tensor, // dst
gmTensor, // src
AscendC::DataCopyParams(1, // nBurst
CeilDiv<BLOCK_SIZE>(nTileActual * dTileActual), // lenBurst
0, // srcGap
0)); // dstGap
};
};
// Partial specialization for NZ_in, NZ_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
uint64_t srcStride = nVal - nTileCeil;
if (srcStride < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor, // dst
gmTensor, // src
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
nTileCeil, // lenBurst
srcStride, // srcGap
0)); // dstGap
} else {
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; i++) {
uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE;
uint64_t srcOffset = i * nVal * BLOCK_SIZE;
AscendC::DataCopy(l1Tensor[dstOffset], // dst
gmTensor[srcOffset], // src
AscendC::DataCopyParams(1, // nBurst
nTileCeil, // lenBurst
0, // srcGap
0)); // dstGap
}
}
};
};
// Partial specialization for V220, ND_in, ND_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
if (dVal < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
nTileActual, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
dVal, // srcDValue
nTileCeil, // dstNzC0Stride
1, // dstNzNStride
0)); // dstNzMatrixStride, unused
} else {
for (uint32_t i = 0; i < nTileActual; i++) {
AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE],
gmTensor[i * dVal],
AscendC::Nd2NzParams(1, // ndNum
1, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
0, // srcDValue
nTileCeil, // dstNzC0Stride
0, // dstNzNStride
0)); // dstNzMatrixStride, unused
}
}
};
};
// Partial specialization for V220, ND_in, NZ_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ZN> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
if (dVal < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
nTileActual, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
dVal, // srcDValue
nTileCeil, // dstNzC0Stride
1, // dstNzNStride
0)); // dstNzMatrixStride, unused
} else {
for (uint32_t i = 0; i < nTileActual; ++i) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
1, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
0, // srcDValue
nTileCeil, // dstNzC0Stride
0, // dstNzNStride
0)); // dstNzMatrixStride, unused
}
}
};
};

View File

@@ -0,0 +1,89 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
template <ArchType ArchTag, typename DType> struct gm_to_ub {
__aicore__ inline gm_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DType> struct gm_to_ub_align {
__aicore__ inline gm_to_ub_align(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
{
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0),
AscendC::DataCopyPadExtParams<DType>(false, leftPaddingNum, rightPaddingNum, 0));
};
};
template <ArchType ArchTag, typename DType> struct ub_to_ub {
__aicore__ inline ub_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DataType, DataFormat InDataFormat = DataFormat::ND,
DataFormat OutDataFormat = DataFormat::ND>
struct ub_to_gm {
__aicore__ inline ub_to_gm(AscendC::GlobalTensor<DataType> dstTensor, AscendC::LocalTensor<DataType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DataType> struct ub_to_gm<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ ub_to_gm(AscendC::GlobalTensor<DataType> gmTensor, AscendC::LocalTensor<DataType> ubTensor,
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
uint32_t dTileCeil, uint32_t dVal)
{
constexpr uint32_t STRIDE_LIMIT = 65536;
uint64_t dstStride = nVal - nTileCeil;
if (dstStride < STRIDE_LIMIT) {
AscendC::DataCopy(gmTensor, // dst
ubTensor, // src
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
nTileCeil, // lenBurst
0, // srcGap
dstStride)); // dstGap
} else {
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; ++i) {
uint64_t dstOffset = i * nVal * BLOCK_SIZE;
uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE;
AscendC::DataCopy(gmTensor[dstOffset], // dst
ubTensor[srcOffset], // src
AscendC::DataCopyParams(1, // nBurst
nTileCeil, // lenBurst
0, // srcGap
0)); // dstGap
}
}
};
};
template <ArchType ArchTag, typename DType> struct ub_to_gm_align {
__aicore__ inline ub_to_gm_align(AscendC::GlobalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
{
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0));
};
};

View File

@@ -0,0 +1,228 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
constexpr uint32_t BLOCK_NUM = 16;
constexpr uint32_t BLOCK_SIZE_INT8 = 32;
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, float> {
/**
* @brief Copy data from L0C buffer to global memory, partial specialized for
*
* @param gmTensor the destination tensor on global memory, which is stored in ND format.
* @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format.
* @param mTileActual the m-direction size of the matrix in L0C buffer.
* @param nTileActual the n-direction size of the matrix in L0C buffer.
* @param srcStride the source stride between the adjacent fractal matrix along n-direction in unit of C0_SIZE.
* @param dstStride the leading dimension of the destination matrix in unit of element.
*/
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322F16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::F322F16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, int32_t> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::VDEQF16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<int32_t> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
#ifdef __DAV_C220_CUBE__
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, __bf16, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322BF16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::F322BF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
#endif
// Partial specialization ND, float
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, float, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<float> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<float, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::NoQuant};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::NZ, half, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322F16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, float, AscendC::CFG_NZ>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride - (nTileActual * sizeof(half) / sizeof(float)));
intriParams.quantParams = {QuantMode_t::F322F16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, int32_t, int32_t> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<int32_t> gmTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<int32_t, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<int32_t> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};

View File

@@ -0,0 +1,42 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l0c_to_l1
/////////////////////////////////////////////////////
// Partial specialization ZN, half, int32_t
template <ArchType ArchTag>
struct l0c_to_l1<ArchTag, DataFormat::ZN, half, int32_t> {
using ElementOut = half;
using ElementIn = int32_t;
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor,
AscendC::LocalTensor<ElementIn> l0cTensor,
AscendC::LocalTensor<uint64_t> deqTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t mTileCeil,
uint32_t nActual)
{
constexpr uint32_t BLOCK_NUM = 16;
constexpr uint32_t BLOCK_SIZE = 32;
AscendC::FixpipeParams<ElementIn> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE),
0,
mTileCeil - static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) *
sizeof(ElementOut) / sizeof(ElementIn));
intriParams.nz2ndParams = {false, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams);
};
};

View File

@@ -0,0 +1,71 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l0c_to_ub
/////////////////////////////////////////////////////
// Partial specialization ZN, half, int32_t
template <ArchType ArchTag, typename ElementIn, typename ElementOut, bool MatrixMode = true>
struct l0c_to_ub {
__aicore__ l0c_to_ub(AscendC::LocalTensor<ElementOut> ubTensor,
AscendC::LocalTensor<ElementIn> l0cTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
constexpr auto mode =
MatrixMode ? AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR;
AscendC::DataCopy(ubTensor,
l0cTensor,
AscendC::DataCopyParams(nBurst, // count
lenBurst, // len
srcStride, // srcStrideIn
dstStride), // dstStrideIn
AscendC::DataCopyEnhancedParams(mode, // blockModeIn
AscendC::DeqScale::DEQ_NONE, // deqScaleIn
0, // deqValueIn
0, // sidStoreModeIn
false, // isReluIn
pad_t::PAD_NONE, // padModeIn
0) // padValueIn
);
};
};
template <ArchType ArchTag>
struct l0c_to_ub<ArchTag, int32_t, half> {
__aicore__ l0c_to_ub(AscendC::LocalTensor<half> ubTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(ubTensor,
l0cTensor,
AscendC::DataCopyParams(nBurst, // count
lenBurst, // len
srcStride, // srcStrideIn
dstStride), // dstStrideIn
AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn
AscendC::DeqScale::VDEQ16, // deqScaleIn
0, // deqValueIn
0, // sidStoreModeIn
false, // isReluIn
pad_t::PAD_NONE, // padModeIn
0) // padValueIn
);
};
};

View File

@@ -0,0 +1,39 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_bt
/////////////////////////////////////////////////////
// Partial specialization for V220
template <typename DataType>
struct l1_to_bt<ArchType::ASCEND_V220, DataType> {
__aicore__ l1_to_bt(uint64_t dst,
const AscendC::LocalTensor<DataType> &src,
uint16_t convControl,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcGap,
uint16_t dstGap)
{
AscendC::LocalTensor<DataType> dstTensor;
dstTensor.InitBuffer(dst, nBurst * lenBurst);
dstTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
AscendC::DataCopy(dstTensor,
src,
AscendC::DataCopyParams(nBurst, // nBurst
lenBurst, // lenBurst
srcGap, // srcGap
dstGap)); // dstGap
}
};

View File

@@ -0,0 +1,36 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_fb
/////////////////////////////////////////////////////
// Partial specialization for V220
template <typename DataType>
struct l1_to_fb<ArchType::ASCEND_V220, DataType> {
__aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst,
AscendC::LocalTensor<DataType> &src,
uint16_t burstNum,
uint16_t burstLen,
uint16_t srcGap,
uint16_t dstGap)
{
dst.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2PIPE2GM);
AscendC::DataCopy(dst,
src,
AscendC::DataCopyParams(burstNum, // nBurst
burstLen, // lenBurst
srcGap, // srcGap
dstGap)); // dstGap);
}
};

View File

@@ -0,0 +1,310 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_l0_a
/////////////////////////////////////////////////////
// Partial specialization for vector
template <ArchType ArchTag, typename DataType, bool IsTransPose>
struct l1_to_l0_a<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(l0Tensor,
l1Tensor,
AscendC::LoadData2dParams(0, // baseIdx
kPartCeil, // repeat
kSrcStride, // srcStride
0, // sid
kDstStride, // dstStride
IsTransPose, // transpose
0)); // addrCalMode
};
};
// Partial specialization for no transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) {
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], // dst
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // src
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(kPartCeil / BLOCK_SIZE), // repeat
kSrcStride, // srcStride
0, // sid
kDstStride - 1, // dstStride
false, // transpose
0)); // addrCalMode
}
};
};
// Partial specialization for transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, true, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) {
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE],
l1Tensor[i * mSrcStride * FRACTAL_SIZE],
AscendC::LoadData2dParams(0,
static_cast<uint16_t>(kPartCeil / BLOCK_NUM_PER_FRACTAL),
kSrcStride,
0,
kDstStride - 1,
true,
0));
}
};
};
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::NZ, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
// 16 * 32
static constexpr uint32_t ROW_BLOCK_SIZE = 16;
static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) {
AscendC::LoadData(l0Tensor[i * ROW_BLOCK_SIZE * kPartCeil],
l1Tensor[i * FRACTAL_SIZE],
AscendC::LoadData2dParams(0,
static_cast<uint16_t>(kPartCeil / COL_BLOCK_SIZE),
mTileCeil / ROW_BLOCK_SIZE,
0,
0,
false,
0));
}
};
};
template <>
struct l1_to_l0_a<ArchType::ASCEND_V220, int8_t, true, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 512
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; // 16
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<int8_t> l0Tensor,
AscendC::LocalTensor<int8_t> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint64_t i = 0; i < mTileCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
AscendC::LoadDataWithTranspose(
l0Tensor[i * mDstStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // dstLocalTensor
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // srcLocalTensor
AscendC::LoadData2dTransposeParams(0, // baseIdx
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(kPartCeil)), // repeat
kSrcStride, // srcStride
0, // dstGap
mDstStride - 1)); // dstFracGap
}
}
};
/////////////////////////////////////////////////////
// l1_to_l0_b
/////////////////////////////////////////////////////
// Partial specialization for vector
template <ArchType ArchTag, typename DataType, bool IsTransPose>
struct l1_to_l0_b<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(
l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0));
};
};
template <ArchType ArchTag>
struct l1_to_l0_b<ArchTag, int8_t, true, DataFormat::NZ, DataFormat::ZN> {
using HardwareParams = HardwareInfo<ArchTag>;
using DataType = int8_t;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) {
AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE],
l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE],
AscendC::LoadData2dTransposeParams(0, // startIndexIn
kPartCeil / BLOCK_SIZE, // repeatTimesIn
nTileCeil / BLOCK_SIZE, // srcStrideIn
1, // dstGapIn
0, // dstfracGapIn
0) // addrModeIn
);
}
};
};
// Partial specialization for no transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_b<ArchTag, DataType, false, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < kPartCeil / BLOCK_NUM_PER_FRACTAL; i++) {
AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE],
l1Tensor[i * kSrcStride * FRACTAL_SIZE],
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(nTileCeil / BLOCK_SIZE), // repeat
nSrcStride, // srcStride
0, // sid
nDstStride - 1, // dstStride
true, // transpose
0)); // addrCalMode
}
};
};
// Partial specialization for transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_b<ArchTag, DataType, true, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(
l0Tensor,
l1Tensor,
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat
1, // srcStride
0, // sid
0, // dstStride
false, // transpose
0)); // addr_cal_mode_t
};
};
template <>
struct l1_to_l0_b<ArchType::ASCEND_V220, int8_t, false, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 16
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<int8_t> l0Tensor,
AscendC::LocalTensor<int8_t> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint64_t i = 0; i < kPartCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
AscendC::LoadDataWithTranspose(
l0Tensor[i * kDstStride * FRACTAL_SIZE], // dstLocalTensor
l1Tensor[i * kSrcStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // srcLocalTensor
AscendC::LoadData2dTransposeParams(0, // baseIdx
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(nTileCeil)), // repeat
nSrcStride / NUM_FRACTAL_PER_ITER, // srcStride
1, // dstGap
0)); // dstFracGap
}
};
};

View File

@@ -0,0 +1,44 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_ub
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct l1_to_ub {
__aicore__ l1_to_ub(AscendC::LocalTensor<DataType> ubTensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
/////////////////////////////////////////////////////
// ub_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct ub_to_l1 {
__aicore__ ub_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::LocalTensor<DataType> ubTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};