/** * Copyright (c) 2024 Huawei Technologies Co., Ltd. * This file is a part of the CANN Open Software. * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). * Please refer to the License for details. You may not use this file except in compliance with the License. * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ /*! * \file dropmask.h * \brief */ #ifndef DROPMASK_H #define DROPMASK_H #include "util.h" using AscendC::DROPOUT_MODE_BIT_MISALIGN; using AscendC::DropOutShapeInfo; using AscendC::DropOut; struct DropMaskInfo { // for compute dropout mask offset // 参数按B N G S1 S2全部切分设置进行偏移计算,没有切分的轴对应的参数设置为合适的0或者原始值 int64_t n2G; // n2 * g int64_t gSize; // g int64_t s1Size; // s1 int64_t s2Size; // s2 int64_t gOutIdx; // g out index int64_t bSSOffset; // boidx * s1 * s2 ===bSSOffset int64_t n2OutIdx; // n out index int64_t s1OutIdx; // s1 out index ===s1oIdx int64_t s1InnerIdx; // s1 inner index, 配比 ===loopIdx int64_t s1BaseSize; // S1基本块大小 int64_t splitS1BaseSize; // s1 split size ===vec1S1BaseSize int64_t s2StartIdx; // s2 start index int64_t s2Idx; // s2 index =====s2LoopCount int64_t s2BaseNratioSize; // s2的配比长度: s2BaseSize(S2基本块大小) * nRatio // for copy in dropout mask uint32_t s1CopySize; uint32_t s2CopySize; int64_t s2TotalSize; // for compute dropout mask uint32_t firstAxis; uint32_t lstAxis; uint32_t maskLstAxis; int64_t vecCoreOffset = 0; float keepProb; bool boolMode; }; template __aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo) { if constexpr (hasDrop == true) { // boidx * n2 * g* s1 * s2 int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G; // n2oIdx * g * s1 *s2 int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size; // goIdx * s1 * s2 int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size; // s1oIdx * s1BaseSize * s2Size + s1innerindex * vec1S1BaseSize * s2Size int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset + dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size; // s2StartIdx + s2index * s2BaseNratioSize int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize; return bOffset + n2Offset + gOffset + s1Offset + s2Offset; } else { return 0; } } template __aicore__ inline void CopyInDropMask(LocalTensor&dstTensor, GlobalTensor& srcBoolTensor, GlobalTensor& srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes) { if constexpr (hasDrop == true) { int64_t dropMaskOffset = ComputeDropOffset(dropMaskInfo); if (unlikely(dropMaskInfo.boolMode)) { BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset, dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); } else { Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1, dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); } return; } } template __aicore__ inline void ComputeDropMask(LocalTensor& dstTensor, LocalTensor& srcTensor, LocalTensor& dropoutBuffer, LocalTensor& tmpDropBuffer, DropMaskInfo &dropMaskInfo) { if constexpr (hasDrop == true) { DropOutShapeInfo dropOutShapeInfo; dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis; dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis; if (unlikely(dropMaskInfo.boolMode)) { dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes; DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); } else { dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes; if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) { DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); } else { DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); } } return; } } #endif // DROPMASK_H