[Kernel] add custom op GmmSwigluQuantWeightNzTensorList (#3804)
### What this PR does / why we need it?
This PR introduces support for adding custom CANN `aclnn` ops to `vllm-ascend`, allowing users to define and use their own custom operators. Key changes include:
- Building and installing custom ops into the `vllm-ascend`-specified directory
- Binding the `aclnn` op interface to the `torch.ops._C_ascend` module
- Enabling invocation of these ops within `vllm-ascend`

This PR includes a sample custom op, `aclnnGroupedMatmulSwigluQuantWeightNzTensorList`, adapted from the CANN operator [`aclnnGroupedMatmulSwigluQuantWeightNZ`](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/API/aolapi/context/aclnnGroupedMatmulSwigluQuantWeightNZ.md). Its input parameters `weight` and `weight_scale` now accept `list[torch.Tensor]` (i.e., `at::TensorList`).

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.11.2

---------

Signed-off-by: QianChenxi <chenxi.qian.cq@outlook.com>
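For context, binding a custom `aclnn` op into the `torch.ops._C_ascend` namespace is typically done through PyTorch's C++ library registration API. The sketch below is illustrative only and is not the code in this PR: the schema string, wrapper name, and placeholder body are assumptions; only `TORCH_LIBRARY_FRAGMENT` / `TORCH_LIBRARY_IMPL` and the `PrivateUse1` dispatch key are standard PyTorch machinery.

```cpp
// Hypothetical sketch: registering a custom aclnn op under torch.ops._C_ascend.
// Names and schema below are assumptions for illustration, not this PR's exact code.
#include <tuple>
#include <ATen/ATen.h>
#include <torch/library.h>

// Wrapper that would query the workspace and launch the aclnn kernel
// (aclnnGroupedMatmulSwigluQuantWeightNzTensorList) on the NPU.
std::tuple<at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weight_nz_tensor_list(
    const at::Tensor &x, at::TensorList weight, at::TensorList weight_scale,
    const at::Tensor &x_scale, const at::Tensor &group_list) {
  // ... workspace query + aclnn kernel launch would go here ...
  return {at::empty_like(x), at::empty_like(x_scale)};  // placeholder outputs
}

// Declare the schema so the op appears as
// torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list.
// Note the Tensor[] arguments, which map to at::TensorList in C++.
TORCH_LIBRARY_FRAGMENT(_C_ascend, m) {
  m.def(
      "grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, "
      "Tensor[] weight_scale, Tensor x_scale, Tensor group_list) -> (Tensor, Tensor)");
}

// Bind the implementation for the NPU backend (PrivateUse1 on Ascend).
TORCH_LIBRARY_IMPL(_C_ascend, PrivateUse1, m) {
  m.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list",
         &grouped_matmul_swiglu_quant_weight_nz_tensor_list);
}
```

Once registered this way, the op is invocable from Python as `torch.ops._C_ascend.<op_name>(...)` with no further glue code.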
csrc/utils/inc/kernel/dropmask.h (new file, 121 lines)
@@ -0,0 +1,121 @@
/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file dropmask.h
 * \brief
 */

#ifndef DROPMASK_H
#define DROPMASK_H

#include "util.h"

using AscendC::DROPOUT_MODE_BIT_MISALIGN;
using AscendC::DropOutShapeInfo;
using AscendC::DropOut;

struct DropMaskInfo {
    // for computing the dropout mask offset
    // Offsets are computed as if the B, N, G, S1, and S2 axes are all split;
    // for an axis that is not split, set its parameters to 0 or the original value as appropriate.
    int64_t n2G;              // n2 * g
    int64_t gSize;            // g
    int64_t s1Size;           // s1
    int64_t s2Size;           // s2
    int64_t gOutIdx;          // g out index
    int64_t bSSOffset;        // boIdx * s1 * s2 ===bSSOffset
    int64_t n2OutIdx;         // n out index
    int64_t s1OutIdx;         // s1 out index ===s1oIdx
    int64_t s1InnerIdx;       // s1 inner index (n-ratio) ===loopIdx
    int64_t s1BaseSize;       // base block size along S1
    int64_t splitS1BaseSize;  // s1 split size ===vec1S1BaseSize
    int64_t s2StartIdx;       // s2 start index
    int64_t s2Idx;            // s2 index =====s2LoopCount
    int64_t s2BaseNratioSize; // n-ratio length along s2: s2BaseSize (base block size along S2) * nRatio

    // for copying in the dropout mask
    uint32_t s1CopySize;
    uint32_t s2CopySize;
    int64_t s2TotalSize;

    // for computing the dropout mask
    uint32_t firstAxis;
    uint32_t lstAxis;
    uint32_t maskLstAxis;
    int64_t vecCoreOffset = 0;
    float keepProb;

    bool boolMode;
};

template <bool hasDrop>
__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        // boIdx * n2 * g * s1 * s2
        int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G;
        // n2oIdx * g * s1 * s2
        int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // goIdx * s1 * s2
        int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size;
        // s1oIdx * s1BaseSize * s2Size + s1InnerIdx * vec1S1BaseSize * s2Size
        int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset +
                            dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size;
        // s2StartIdx + s2Idx * s2BaseNratioSize
        int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize;
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}

template <bool hasDrop>
__aicore__ inline void CopyInDropMask(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcBoolTensor,
    GlobalTensor<uint8_t> &srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes)
{
    if constexpr (hasDrop == true) {
        int64_t dropMaskOffset = ComputeDropOffset<hasDrop>(dropMaskInfo);
        if (unlikely(dropMaskInfo.boolMode)) {
            BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset,
                       dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        } else {
            Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1,
                           dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize);
        }
        return;
    }
}

template <typename T, bool hasDrop>
__aicore__ inline void ComputeDropMask(LocalTensor<T> &dstTensor, LocalTensor<T> &srcTensor,
    LocalTensor<uint8_t> &dropoutBuffer, LocalTensor<uint8_t> &tmpDropBuffer, DropMaskInfo &dropMaskInfo)
{
    if constexpr (hasDrop == true) {
        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis;
        dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis;

        if (unlikely(dropMaskInfo.boolMode)) {
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes;
            DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
        } else {
            dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes;
            if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) {
                DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo);
            } else {
                DropOut<T, false, DROPOUT_MODE_BIT_MISALIGN>(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer,
                                                             dropMaskInfo.keepProb, dropOutShapeInfo);
            }
        }
        return;
    }
}

#endif // DROPMASK_H
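To make the offset arithmetic in `ComputeDropOffset` concrete, here is a small host-side C++ sketch (editorial, not part of the diff) that evaluates the same B/N2/G/S1/S2 decomposition for a hypothetical tiling; all tile indices and sizes below are invented for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Host-side mirror of ComputeDropOffset for a hypothetical tiling:
// n2 = 2, g = 3, s1 = 128, s2 = 256, with all axes split.
// vecCoreOffset (default 0) is omitted.
int main() {
    const int64_t gSize = 3, s1Size = 128, s2Size = 256;
    const int64_t n2G = 2 * gSize;
    const int64_t boIdx = 1, n2OutIdx = 1, gOutIdx = 2;
    const int64_t s1OutIdx = 1, s1BaseSize = 64, s1InnerIdx = 1, splitS1BaseSize = 16;
    const int64_t s2StartIdx = 0, s2Idx = 1, s2BaseNratioSize = 64;

    const int64_t bSSOffset = boIdx * s1Size * s2Size;                // boIdx * s1 * s2
    const int64_t bOffset  = bSSOffset * n2G;                         // boIdx * n2 * g * s1 * s2
    const int64_t n2Offset = n2OutIdx * gSize * s1Size * s2Size;      // n2oIdx * g * s1 * s2
    const int64_t gOffset  = gOutIdx * s1Size * s2Size;               // goIdx * s1 * s2
    const int64_t s1Offset = (s1OutIdx * s1BaseSize + s1InnerIdx * splitS1BaseSize) * s2Size;
    const int64_t s2Offset = s2StartIdx + s2Idx * s2BaseNratioSize;

    const int64_t offset = bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    assert(offset == 196608 + 98304 + 65536 + 20480 + 64);
    return 0;
}
```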
csrc/utils/inc/kernel/pse.h (new file, 483 lines)
@@ -0,0 +1,483 @@
/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file pse.h
 * \brief
 */

#ifndef FLASH_ATTENTION_SCORE_PSE_H
#define FLASH_ATTENTION_SCORE_PSE_H

#include "kernel_operator.h"
#include "util.h"

constexpr static int64_t pseS1S2 = 0;
constexpr static int64_t pse1S2 = 1;
constexpr static int64_t pseSlopeBn = 2;
constexpr static int64_t pseSlopeN = 3;

constexpr static uint8_t pseEncodeALibiS2Full = 0x11;

enum class PseTypeEnum {
    PSE_OUTER_MUL_ADD_TYPE = 0, // default
    PSE_OUTER_ADD_MUL_TYPE,
    PSE_INNER_MUL_ADD_TYPE,
    PSE_INNER_MUL_ADD_SQRT_TYPE,
    PSE_INVALID_TYPE
};

struct PseInfo {
    int64_t blockCount;
    int64_t bSSOffset;        // boIdx * s1 * s2
    int64_t boIdx;
    int64_t gSize;
    int64_t goIdx;
    int64_t loopIdx;
    int64_t n2G;
    int64_t n2oIdx;
    int64_t pseBSize;
    int64_t pseS1Size;        // for alibi
    int64_t pseS2ComputeSize; // for alibi, does not need assignment
    int64_t pseS2Size;        // for alibi
    uint32_t pseShapeType;
    int64_t readS2Size;       // for alibi, does not need assignment
    int64_t s1BaseSize;
    int64_t s1Size;
    int64_t s1oIdx;
    int64_t s2AlignedSize;
    int64_t s2BaseNratioSize;
    int64_t s2LoopCount;
    int64_t s2RealSize;
    int64_t s2Size;
    int64_t s2SizeAcc;        // accumulated sum of s2 sizes
    int64_t s2StartIdx;
    int64_t vec1S1BaseSize;
    int64_t vec1S1RealSize;
    uint32_t pseEncodeType;   // for distinguishing alibi
    uint32_t pseType;         // 0: outer, mul-add; 1: outer, add-mul; 2: inner, mul-add; 3: inner, mul-add-sqrt
    int64_t pseAlibiBaseS1;
    int64_t pseAlibiBaseS2;
    int64_t qStartIdx;
    int64_t kvStartIdx;
    int64_t vecCoreOffset = 0;
    bool needCast;
    bool align8 = false;
    bool pseEndogenous = false;
};

template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInCommon(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                        int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize,
                                        int32_t alignedS2Size)
{
    if constexpr (hasPse == true) {
        uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignedS2Size)};
        dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        dstTensor.SetSize(s1Size * alignedS2Size);
        DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = s1Size;
        dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // unit: 32B
        dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap
        if (actualS2Len * dtypeSize % blockBytes == 0) {
            dataCopyParams.srcStride =
                (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap
            DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
        } else {
            dataCopyParams.blockLen = s2Size * dtypeSize; // unit: byte
            dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen);
            dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes;
            DataCopyPadParams dataCopyPadParams;
            dataCopyPadParams.isPad = false;
            DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
        }
    }
}

template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                  int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}

template <typename INPUT_T, bool hasPse>
__aicore__ inline void DataCopyInAlign8(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                        int64_t s1Size, int64_t s2Size, int64_t actualS2Len)
{
    if constexpr (hasPse == true) {
        int32_t dtypeSize = sizeof(INPUT_T);
        if (dtypeSize == 0) {
            return;
        }
        int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize);
        DataCopyInCommon<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, s2Size,
                                          actualS2Len, dtypeSize, alignedS2Size);
    }
}

/*
    dst = BroadcastAdd(src0, src1)
    src0 shape: (s1, s2)
    src1 shape: (1, s2)
    dst shape: (s1, s2)
*/
template <typename T, bool hasPse>
__aicore__ inline void BroadcastAdd(const LocalTensor<T> &src0Tensor, const LocalTensor<T> &src1Tensor,
                                    int64_t src0Offset, int32_t src1Size, int32_t repeatTimes)
{
    if constexpr (hasPse == true) {
        /* Total data number of a single step should be smaller than 256 bytes.
         * If larger, we need to do the add multiple times. */
        int32_t innerLoop = src1Size / repeatMaxSize;   // number of full-block iterations along the s2 axis
        int32_t innerRemain = src1Size % repeatMaxSize; // tail-block size along the s2 axis
        BinaryRepeatParams binaryRepeatParams;
        binaryRepeatParams.src0BlkStride = 1;
        binaryRepeatParams.src0RepStride = src1Size / blockSize;
        binaryRepeatParams.src1BlkStride = 1;
        binaryRepeatParams.src1RepStride = 0;
        binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride;
        binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride;

        for (int32_t j = 0; j < innerLoop; j++) {
            auto innerOffset = j * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes,
                binaryRepeatParams);
        }
        if (innerRemain > 0) {
            auto innerOffset = innerLoop * repeatMaxSize;
            auto ubOffset = src0Offset + innerOffset;
            Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes,
                binaryRepeatParams);
        }
    }
}

template <typename T, bool hasPse>
__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor<T> &pseUb,
                                       const LocalTensor<T> &dstTensor, uint32_t pseShapeType)
{
    if constexpr (hasPse == true) {
        if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) {
            Add(dstTensor, dstTensor, pseUb, computeSize);
        } else {
            /* Total repeated times should be <= repeatMaxTimes. If larger,
             * we need to do multiple inner loops. */
            int32_t s1OuterLoop = s1Size / repeatMaxTimes;
            int32_t s1OuterRemain = s1Size % repeatMaxTimes;
            for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) {
                int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes);
            }
            if (s1OuterRemain > 0) {
                int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size;
                BroadcastAdd<T, hasPse>(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain);
            }
        }
    }
}

template <bool hasPse> __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t s1Offset = 0;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t gOffset = 0;
        if (pseInfo.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = pseInfo.bSSOffset * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size;
            s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                        pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size;
        } else if (pseInfo.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = pseInfo.s2SizeAcc * pseInfo.n2G;
            n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size;
            gOffset = pseInfo.goIdx * pseInfo.s2Size;
        }
        if (pseInfo.pseBSize == 1) {
            bOffset = 0;
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}

template <LayOutTypeEnum layOutType, bool hasPse> __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size;
        int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                      pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;
        int64_t m = 0;
        int64_t k = 0;
        if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
            int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size;
            if (row >= threshold) {
                m = row - threshold;
                k = column;
            } else {
                m = row % pseInfo.pseS1Size;
                k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m);
            }
        } else {
            int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size;
            int64_t posVal = row - column - threshold;
            if (threshold >= 0) {
                if (posVal >= 0) {
                    m = posVal;
                    k = 0;
                } else {
                    m = 0;
                    k = -posVal;
                }
            } else {
                m = posVal;
                k = 0;
            }
        }
        int64_t s1Offset = m * pseInfo.pseS2Size;
        int64_t s2Offset = k;
        pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k);
        pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size);

        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    } else {
        return 0;
    }
}

template <bool hasPse> __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        // Alibi encoding only computes the lower triangle
        if (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize <=
            pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize) {
            return false;
        }
        return true;
    } else {
        return false;
    }
}

template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseAlibiCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                      GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        int64_t offset = PseAlibiComputeOffset<layOutType, hasPse>(pseInfo);
        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8) {
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                            pseInfo.pseS2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize,
                                                  pseInfo.readS2Size, pseInfo.pseS2Size);
            }
            return;
        }

        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size,
                                    pseInfo.pseS2Size, alignedSize);
        if (pseInfo.needCast) {
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        }
        return;
    }
}

template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCopyIn(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                      __gm__ uint8_t *pseSlope, GlobalTensor<half> &alibiGm, PseInfo &pseInfo,
                                      int64_t alignedSize = 16) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;
        }
        int64_t offset = bOffset + n2Offset + gOffset;

        DataCopyIn<half, hasPse>(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize,
                                 pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize);
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);

        if (pseInfo.needCast) {
            int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
            Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
            pipe_barrier(PIPE_V);

            int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                               pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
            int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

            float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

            Adds(dstTensor, dstTensor, posShift, computeSize);
            pipe_barrier(PIPE_V);
            Abs(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
            float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
            if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
                Sqrt(dstTensor, dstTensor, computeSize);
                pipe_barrier(PIPE_V);
            }
            Muls(dstTensor, dstTensor, slopes, computeSize);
            pipe_barrier(PIPE_V);
        }
    }
}

template <typename T, bool hasPse>
__aicore__ inline void PseSlopeCast(LocalTensor<T> &dstTensor, LocalTensor<half> &helpTensor,
                                    __gm__ uint8_t *pseSlope, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize;
        int64_t gOffset = pseInfo.goIdx;

        if (pseInfo.pseShapeType == pseSlopeBn) {
            bOffset = pseInfo.boIdx * pseInfo.n2G;
        }
        int64_t offset = bOffset + n2Offset + gOffset;
        int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize;
        Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize);
        pipe_barrier(PIPE_V);

        int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset +
                           pseInfo.loopIdx * pseInfo.vec1S1BaseSize;
        int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize;

        float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx);

        Adds(dstTensor, dstTensor, posShift, computeSize);
        pipe_barrier(PIPE_V);
        Abs(dstTensor, dstTensor, computeSize);
        pipe_barrier(PIPE_V);
        float slopes = ((__gm__ T *)pseSlope)[offset] * -1;
        if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            Sqrt(dstTensor, dstTensor, computeSize);
            pipe_barrier(PIPE_V);
        }
        Muls(dstTensor, dstTensor, slopes, computeSize);
        pipe_barrier(PIPE_V);
    }
}

template <typename INPUT_T, typename T, LayOutTypeEnum layOutType, bool hasPse>
__aicore__ inline void PseCopyIn(LocalTensor<T> &dstTensor, LocalTensor<INPUT_T> &tmpTensor,
                                 GlobalTensor<INPUT_T> &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCopyIn<INPUT_T, T, layOutType, hasPse>(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize);
        }
        int64_t offset = PseComputeOffset<hasPse>(pseInfo);
        int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) :
                                                          pseInfo.vec1S1RealSize;

        if constexpr (IsSameType<INPUT_T, T>::value) {
            if (!pseInfo.align8) {
                DataCopyIn<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize,
                                            pseInfo.s2Size, alignedSize);
            } else {
                DataCopyInAlign8<INPUT_T, hasPse>(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size);
            }
            return;
        }
        DataCopyIn<INPUT_T, hasPse>(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size,
                                    alignedSize);
        if (pseInfo.needCast) {
            event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize);
        }
        return;
    }
}

template <typename T, bool hasPse>
__aicore__ inline void PseAlibiCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (!NeedPseAlibiCompute<hasPse>(pseInfo)) {
            return;
        }
        Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize);
        return;
    }
}

template <typename T, bool hasPse>
__aicore__ inline void PseCompute(LocalTensor<T> &dstTensor, LocalTensor<T> &pseTensor, PseInfo &pseInfo)
{
    if constexpr (hasPse == true) {
        if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) {
            return PseAlibiCompute<T, hasPse>(dstTensor, pseTensor, pseInfo);
        }
        int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn ||
                               pseInfo.pseShapeType == pseSlopeN)
                                  ? pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize
                                  : pseInfo.s2AlignedSize;
        PseBroadcastAdd<T, hasPse>(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor,
                                   dstTensor, pseInfo.pseShapeType);
        return;
    }
}

template <bool hasPse>
__aicore__ inline void PseInnerAlibiCreate(GlobalTensor<half> &dstTensor, LocalTensor<half> &helpTensor, PseInfo &pseInfo) {
    if constexpr (hasPse == true) {
        if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE &&
            pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) {
            return;
        }
        event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        event_t eventIdMte3ToS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S));
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        float tmpValue = -1.0;

        for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) {
            CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2);
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            SetFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
            WaitFlag<HardEvent::MTE3_S>(eventIdMte3ToS);
        }
    }
}
#endif // FLASH_ATTENTION_SCORE_PSE_H
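The broadcast-add pattern used by `BroadcastAdd`/`PseBroadcastAdd` above is, in scalar terms, a row-wise add of a (1, s2) bias onto an (s1, s2) tile. Below is a minimal host-side C++ reference of that semantics (editorial, not part of the diff); plain loops stand in for the vectorized `Add` calls driven by `BinaryRepeatParams`.

```cpp
#include <cassert>
#include <vector>

// Reference semantics of dst = BroadcastAdd(src0, src1):
// src0 shape (s1, s2), src1 shape (1, s2), result written in place into src0.
void BroadcastAddRef(std::vector<float> &src0, const std::vector<float> &src1, int s1, int s2) {
    for (int i = 0; i < s1; i++) {      // every s1 row reuses the same bias row,
        for (int j = 0; j < s2; j++) {  // which is what src1RepStride = 0 expresses
            src0[i * s2 + j] += src1[j];
        }
    }
}

int main() {
    const int s1 = 2, s2 = 3;
    std::vector<float> src0 = {0, 1, 2, 3, 4, 5};
    std::vector<float> src1 = {10, 20, 30};
    BroadcastAddRef(src0, src1, s1, s2);
    assert(src0[0] == 10 && src0[5] == 35);  // {10,21,32,13,24,35}
    return 0;
}
```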
csrc/utils/inc/kernel/util.h (new file, 144 lines)
@@ -0,0 +1,144 @@
/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file util.h
 * \brief
 */

#ifndef FLASH_ATTENTION_UTIL_H
#define FLASH_ATTENTION_UTIL_H

constexpr int32_t blockBytes = 32;
constexpr int32_t byteBitRatio = 8;
constexpr int64_t prefixAttenMaskDownHeight = 1024;
constexpr static int32_t blockSize = blockBytes / 4; // 4 means sizeof(T)
constexpr static int32_t repeatMaxBytes = 256;
constexpr static int32_t repeatMaxTimes = 255;
constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T)

using AscendC::LocalTensor;
using AscendC::GlobalTensor;
using AscendC::DataFormat;
using AscendC::ShapeInfo;
using AscendC::DataCopyParams;
using AscendC::DataCopyPadParams;
using AscendC::BinaryRepeatParams;
using AscendC::IsSameType;
using AscendC::HardEvent;
using AscendC::SetFlag;
using AscendC::WaitFlag;

enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5 };

namespace math {
template <typename T> __aicore__ inline T Ceil(T a, T b)
{
    if (b == 0) {
        return 0;
    }
    return (a + b - 1) / b;
}

template <typename T> __aicore__ inline T Align(T a, T b)
{
    if (b == 0) {
        return 0;
    }
    return (a + b - 1) / b * b;
}
}

template <typename T1, typename T2>
__aicore__ inline T1 CeilDiv(T1 a, T2 b)
{
    if (b == 0) {
        return 0;
    }
    return (a + b - 1) / b;
}

template <typename T1, typename T2>
__aicore__ inline T1 Max(T1 a, T2 b)
{
    return (a > b) ? (a) : (b);
}

template <typename T1, typename T2>
__aicore__ inline T1 Min(T1 a, T2 b)
{
    return (a > b) ? (b) : (a);
}

__aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes)
{
    uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {s1Size, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size;
    dataCopyParams.dstStride = 0;
    if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0
        dataCopyParams.dstStride = 1;
        alignedSize = blockBytes;
        alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes;
    }
    if (totalS2Size % alignedSize == 0) {
        dataCopyParams.blockLen = alignedS2Size / blockBytes;
        dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams);
    } else {
        dataCopyParams.blockLen = s2Size;
        dataCopyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes);
        dataCopyPadParams.paddingValue = 1;
        DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams);
    }
}

__aicore__ inline void Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
    int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize,
    int64_t alignedSize = blockBytes)
{
    uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize;
    uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = batchSize * s1BaseSize;
    dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes);
    dataCopyParams.dstStride = 0;
    if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) {
        dataCopyParams.srcStride =
            (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes;
        DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams);
    } else {
        dataCopyParams.blockLen = CeilDiv(s2BaseSize, byteBitRatio);
        dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = 0;
        dataCopyPadParams.paddingValue = 0;
        DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams);
    }
}

__aicore__ inline int32_t Align(int32_t shape)
{
    int32_t alignFactor = 16;
    int32_t alignedSize = CeilDiv<int32_t, int32_t>(shape, alignFactor) * alignFactor;
    return alignedSize;
}

#endif // FLASH_ATTENTION_UTIL_H
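For reference, the alignment helpers in `util.h` round sizes up to hardware block multiples. The host-side sketch below (editorial, not part of the diff) mirrors `CeilDiv` and `Align` and checks a couple of values:

```cpp
#include <cassert>
#include <cstdint>

// Host-side mirrors of util.h's CeilDiv and Align helpers.
template <typename T1, typename T2> T1 CeilDivRef(T1 a, T2 b) { return b == 0 ? 0 : (a + b - 1) / b; }
template <typename T> T AlignRef(T a, T b) { return b == 0 ? 0 : (a + b - 1) / b * b; }

int main() {
    // blockBytes = 32: a 100-byte row occupies 4 blocks, padded to 128 bytes.
    assert(CeilDivRef(100, 32) == 4);
    assert(AlignRef(100, 32) == 128);
    // The single-argument Align(shape) uses a fixed factor of 16 elements.
    assert(AlignRef(100, 16) == 112);
    return 0;
}
```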