add dispath_ffn_combine_bf16 (#5866)
### What this PR does / why we need it?
add dispath_ffn_combine_bf16
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -0,0 +1,208 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
#include "catlass/epilogue/block/block_epilogue.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
// float scale, dequant per expert
|
||||
template <
|
||||
uint32_t UB_STAGES_,
|
||||
class CType_,
|
||||
class LayoutPerTokenScale_,
|
||||
class DType_,
|
||||
class TileCopy_
|
||||
>
|
||||
class BlockEpilogue <
|
||||
EpilogueAtlasA2PerTokenDequant<UB_STAGES_>,
|
||||
CType_,
|
||||
Gemm::GemmType<float, LayoutPerTokenScale_>,
|
||||
DType_,
|
||||
TileCopy_
|
||||
> {
|
||||
public:
|
||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequant<UB_STAGES_>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
|
||||
// Data infos
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using ElementPerTokenScale = float;
|
||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||
using ElementD = typename DType_::Element;
|
||||
using LayoutD = typename DType_::Layout;
|
||||
|
||||
// Check data infos
|
||||
static_assert(
|
||||
(std::is_same_v<ElementC, half> || std::is_same_v<ElementC, bfloat16_t>) &&
|
||||
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
|
||||
"The element type template parameters of BlockEpilogue are wrong"
|
||||
);
|
||||
static_assert(
|
||||
std::is_same_v<LayoutC, layout::RowMajor> &&
|
||||
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
|
||||
"The layout template parameters of BlockEpilogue are wrong"
|
||||
);
|
||||
|
||||
|
||||
// Tile copy
|
||||
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
|
||||
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
|
||||
|
||||
struct Params {
|
||||
__gm__ int32_t *ptrTokenPerExpert{nullptr};
|
||||
int32_t EP;
|
||||
int32_t expertPerRank;
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params() {};
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
size_t ubOffset = 4096;
|
||||
int32_t eventVMTE2 = 0;
|
||||
int32_t eventMTE2V = 0;
|
||||
int32_t eventMTE3V = 0;
|
||||
int32_t eventVMTE3 = 0;
|
||||
constexpr int32_t blockN = 12000;
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||
ubOffset += blockN * sizeof(ElementC);
|
||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
||||
ubOffset += blockN * sizeof(ElementD);
|
||||
|
||||
eventUbCVMTE2List[i] = eventVMTE2++;
|
||||
eventUbCMTE2VList[i] = eventMTE2V++;
|
||||
eventUbDMTE3VList[i] = eventMTE3V++;
|
||||
eventUbDVMTE3List[i] = eventVMTE3++;
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
||||
ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += blockN * sizeof(float);
|
||||
}
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void Finalize()
|
||||
{
|
||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
||||
}
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
~BlockEpilogue()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void UpdateParams(Params const ¶ms_)
|
||||
{
|
||||
params = params_;
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
MatrixCoord const &shapeC,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
|
||||
AscendC::GlobalTensor<ElementD> const &gmD
|
||||
)
|
||||
{
|
||||
uint32_t blockM = shapeC.row();
|
||||
uint32_t blockN = shapeC.column();
|
||||
|
||||
uint32_t tileLoops = blockM;
|
||||
|
||||
for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx ++) {
|
||||
auto gmTileC = gmC[loopIdx * blockN];
|
||||
auto &ubC = ubCList[ubListId];
|
||||
auto &ubCFp32 = ubCFp32List[ubListId];
|
||||
auto &ubMul = ubMulList[ubListId];
|
||||
auto &ubD = ubDList[ubListId];
|
||||
auto gmTileD = gmD[loopIdx * blockN];
|
||||
LayoutC layoutUbC{1, blockN};
|
||||
|
||||
// Move C from GM workspace to UB
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
|
||||
// Cast C to FP32 in UB
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||
|
||||
// Get per-token scale from row loopIdx of gmPerTokenScale
|
||||
ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||
// Multiply FP32 C by the per-token scale
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
// Cast the muls result back to fp16/bf16
|
||||
LayoutD layoutUbD{1, blockN};
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||
|
||||
AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
||||
copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||
|
||||
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
|
||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||
|
||||
int32_t eventUbCVMTE2List[UB_STAGES];
|
||||
int32_t eventUbCMTE2VList[UB_STAGES];
|
||||
int32_t eventUbDMTE3VList[UB_STAGES];
|
||||
int32_t eventUbDVMTE3List[UB_STAGES];
|
||||
|
||||
uint32_t ubListId{0};
|
||||
|
||||
AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubMulList[UB_STAGES];
|
||||
|
||||
|
||||
CopyGmToUbC copyGmToUbC;
|
||||
CopyUbToGmD copyUbToGmD;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
|
||||
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
|
||||
@@ -0,0 +1,402 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
|
||||
// float scale, dequant per expert
|
||||
// Per-token dequant + SwiGLU (+ optional dynamic int8 quant) epilogue (AtlasA2).
//
// The first operator() performs, row by row:
//   1. GM -> UB copy of one C row, cast to FP32,
//   2. per-token dequant (Muls by gmPerTokenScale1[row]),
//   3. SwiGLU over the two row halves: silu(x_lo) * x_hi,
//   4. dynamic int8 quantization: per-row absmax -> scale, quantized result
//      written to gmD and per-row dequant scales written to gmPerTokenScale2.
// The second operator() does steps 1-3 only and writes the SwiGLU result.
// Work is split across epilogue vector cores; UB_STAGES ping-pong buffers are
// sequenced with hardware event flags between the MTE2/V/MTE3/S pipes.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileElemWiseMuls_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileElemWiseMuls_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(
        (std::is_same_v<ElementC, half> || std::is_same_v<ElementC, bfloat16_t>) &&
        (std::is_same_v<ElementD, float> || std::is_same_v<ElementD, int8_t> || std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
        "The element type template parameters of BlockEpilogue are wrong"
    );
    static_assert(
        std::is_same_v<LayoutC, layout::RowMajor> &&
        std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
        "The layout template parameters of BlockEpilogue are wrong"
    );

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
    // UB -> GM copy for the per-row dequant scales produced by quantization.
    using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm<ArchTag, Gemm::GemmType<ElementPerTokenScale, LayoutPerTokenScale>>;

    // Host-provided output pointers/layouts.
    // NOTE(review): these fields are not read by either operator() in this
    // file (both take GM tensors as arguments) — confirm they are still needed.
    struct Params {
        __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};  // per-token scale output
        LayoutPerTokenScale layoutPerTokenScale{};
        __gm__ ElementD *ptrD{nullptr};                          // D output
        LayoutD layoutD{};

        CATLASS_DEVICE
        Params() {};

        CATLASS_DEVICE
        Params(__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
            __gm__ ElementD *ptrD_, LayoutD const &layoutD_
        ) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_),
            ptrD(ptrD_), layoutD(layoutD_) {}
    };

    // Carves the per-stage UB buffers out of the shared UB arena and primes
    // the "buffer free" event flags so the first Wait does not deadlock.
    // ubQuantS32/ubQuantF16 alias the Abs buffer via ReinterpretCast — the
    // int32/f16 intermediates of the quant cast chain reuse that storage.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        size_t ubOffset = 0;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        // Fixed UB capacity per stage: one full row (both SwiGLU halves).
        // NOTE(review): operator() uses the runtime shapeC.column() without
        // checking it against this capacity — confirm callers never exceed it.
        constexpr uint32_t blockN = 4096;
        constexpr uint32_t ChunkTileLen = blockN / 2;       // one SwiGLU half
        constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;

        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += blockN * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += blockN * sizeof(ElementD);
            ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += blockN * sizeof(float);
            ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += ChunkTileLen * sizeof(float);
            ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += ChunkTileLen * sizeof(float);
            ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += HalfChunkTileLen * sizeof(float);
            // Quant intermediates share storage with the Abs buffer.
            ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<int32_t>();
            ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<half>();

            // One event id per stage for each pipe pair.
            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;

            // Prime "consumer done" flags so the first producer Wait passes.
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }

        // One slot per row this core processes, flushed to GM after the loop.
        ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    }

    // Drains the flags left pending by the constructor / the last iteration.
    CATLASS_DEVICE
    void Finalize()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }
    CATLASS_DEVICE
    ~BlockEpilogue()
    {
    }

    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }
    // Each tile is one row (1 x 7168); each block covers all tokens of one
    // expert = [group[i], 7168].
    //
    // Dequant + SwiGLU + dynamic int8 quant. Rows [loopStartIdx,
    // loopStartIdx + tasksForIdx) of C are processed by this vector core;
    // the quantized halves go to gmD and the per-row dequant scales
    // (absmax / 127) go to gmPerTokenScale2.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale1,
        AscendC::GlobalTensor<ElementD> const &gmD,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale2,

        uint32_t epilogueCoreNum = 40,
        Callback &&callback = Callback{}
    )
    {
        callback();
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();

        uint32_t tileLoops = blockM;  // NOTE(review): unused
        // Flatten (cube block, sub-block) into one vector-core index.
        uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();

        uint32_t subblockNum = get_block_num() * 2;
        uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;

        // The first moveDataCoreNum sub-blocks are reserved for data movement;
        // only the remaining epilogueCoreNum cores run this epilogue.
        if (subblockIdx < moveDataCoreNum) {
            return;
        }
        uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;

        // Split blockM rows over the epilogue cores; the first remainderData
        // cores take one extra row.
        uint32_t perCoreData = blockM / epilogueCoreNum;
        uint32_t remainderData = blockM % epilogueCoreNum;

        uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
        uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);

        uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1);  // NOTE(review): unused

        uint32_t ChunkTileLen = blockN / 2;                // one SwiGLU half
        uint32_t HalfChunkTileLen = ChunkTileLen / 2;      // NOTE(review): unused


        for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {

            auto gmTileC = gmC[loopIdx * blockN];

            auto &ubC = ubCList[ubListId];
            auto &ubD = ubDList[ubListId];

            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];
            auto &ubAbs = ubCFp32ChunkNAbsList[ubListId];
            // auto &ubMax = ubCFp32ChunkNMaxList[ubListId];
            auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId];
            // ubOutputTmp / sharedUbTmpBuffer deliberately alias ubAbs /
            // ubReduceMax — the abs values are dead once the max is known.
            auto &ubOutputTmp = ubAbs;
            auto &sharedUbTmpBuffer = ubReduceMax;
            auto &ubQuantS32 = ubQuantS32List[ubListId];
            auto &ubQuantF16 = ubQuantF16List[ubListId];

            // D rows are half-width: only the SwiGLU result is written out.
            auto gmTileD = gmD[loopIdx * ChunkTileLen];
            LayoutC layoutUbC{1, blockN};

            // Move C from the GM workspace into UB.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);

            // Cast C to FP32 in UB, then release ubC for the next MTE2 load.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);

            // Fetch the per-token scale: row loopIdx of gmPerTokenScale1.
            ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);

            // Make the scalar read visible to the vector pipe.
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Multiply the FP32 C by the per-token scale.
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
            AscendC::PipeBarrier<PIPE_V>();

            // SwiGLU: chunk = silu(x_lo) * x_hi, computed as
            // x_lo / (1 + exp(-x_lo)) over the first half, then multiplied
            // element-wise by the second half of the row.
            AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            // TODO: check whether the division disturbs data beyond this range.
            AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);

            // Quantization: per-row absmax for dynamic int8 scaling.
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();

            AscendC::ReduceMax<float>(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false);
            AscendC::PipeBarrier<PIPE_V>();

            // Hand the reduction result over to the scalar pipe.
            AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
            AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);

            // TODO: compare the efficiency of the two computation approaches.
            ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);

            // Stash this row's dequant scale (absmax / 127) in the staging
            // buffer; the whole buffer is flushed to GM after the loop.
            auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx;
            ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f);

            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Scale into the int8 range [-127, 127].
            AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();

            // Round-to-nearest f32 -> s32, then s32 -> f16 with deq scale 1.0,
            // then f16 -> ElementD; the s32/f16 steps reuse the Abs buffer.
            AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::SetDeqScale(static_cast<half>(1.0));
            AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();

            // NOTE(review): this Wait/Set pair indexes eventUbDVMTE3List for
            // MTE3_V and eventUbDMTE3VList for V_MTE3 — swapped relative to
            // the sibling PerTokenDequant epilogue. Functionally benign only
            // because both lists hold identical ids; should be made consistent.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
            AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            // AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);

            // Copy the quantized half-row out to GM.
            LayoutD layoutUbD{1, ChunkTileLen};
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            // Advance to the next ping-pong stage.
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }

        // Flush this core's per-row dequant scales to GM in one shot.
        if(tasksForIdx > 0){
            LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx};

            // Ensure the scalar SetValue writes are visible to MTE3.
            AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);

            copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2);
        }


    }

    // Dequant-free variant: SwiGLU only (no per-token scale, no quantization).
    // Writes Cast(silu(x_lo) * x_hi) half-rows of C directly to gmD.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementD> const &gmD,
        uint32_t epilogueCoreNum = 40,
        Callback &&callback = Callback{}
    )
    {
        callback();
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();

        uint32_t tileLoops = blockM;  // NOTE(review): unused
        // Flatten (cube block, sub-block) into one vector-core index.
        uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();
        //uint32_t subblockIdx = get_block_idx() * 2 + get_subblockid();

        uint32_t subblockNum = get_block_num() * 2;
        uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;

        // Only the last epilogueCoreNum sub-blocks run this epilogue.
        if (subblockIdx < moveDataCoreNum) {
            return;
        }
        uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;


        // Split blockM rows over the epilogue cores; the first remainderData
        // cores take one extra row.
        uint32_t perCoreData = blockM / epilogueCoreNum;
        uint32_t remainderData = blockM % epilogueCoreNum;

        uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
        uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);

        uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1);  // NOTE(review): unused

        uint32_t ChunkTileLen = blockN / 2;               // one SwiGLU half
        uint32_t HalfChunkTileLen = ChunkTileLen / 2;     // NOTE(review): unused


        for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {

            auto gmTileC = gmC[loopIdx * blockN];

            auto &ubC = ubCList[ubListId];
            auto &ubD = ubDList[ubListId];

            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];

            // D rows are half-width: only the SwiGLU result is written out.
            auto gmTileD = gmD[loopIdx * ChunkTileLen];
            LayoutC layoutUbC{1, blockN};

            // Move C from the GM workspace into UB.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);

            // Cast C to FP32 in UB, then release ubC for the next MTE2 load.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);

            // SwiGLU: silu(x_lo) * x_hi via x_lo / (1 + exp(-x_lo)).
            AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();

            // NOTE(review): as in the first operator(), the MTE3_V wait uses
            // eventUbDVMTE3List and the V_MTE3 set uses eventUbDMTE3VList —
            // swapped but benign since both lists hold identical ids.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
            AscendC::Cast(ubD, ubCFp32ChunkN, AscendC::RoundMode::CAST_ROUND, ChunkTileLen);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);

            // Copy the SwiGLU half-row out to GM.
            LayoutD layoutUbD{1, ChunkTileLen};
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            // copyUbToGmD(gmTileD, ubCFp32ChunkN, layoutUbD, layoutUbD);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            // Advance to the next ping-pong stage.
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }

    }


private:
    Params params;

    // Per-stage UB buffers: raw C input and final D output.
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];

    // Per-stage event ids, one array per pipe pair.
    int32_t eventUbCVMTE2List[UB_STAGES];
    int32_t eventUbCMTE2VList[UB_STAGES];
    int32_t eventUbDMTE3VList[UB_STAGES];
    int32_t eventUbDVMTE3List[UB_STAGES];

    // UB stage used by the current loop iteration.
    uint32_t ubListId{0};

    // FP32 scratch: full row, SwiGLU half-row, |x| half-row, reduce-max area.
    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];
    AscendC::LocalTensor<float> ubCFp32ChunkNList[UB_STAGES];
    AscendC::LocalTensor<float> ubCFp32ChunkNAbsList[UB_STAGES];
    AscendC::LocalTensor<float> ubCFp32ChunkNMaxList[UB_STAGES];
    // Aliases of the Abs buffer (set via ReinterpretCast in the constructor).
    AscendC::LocalTensor<int32_t> ubQuantS32List[UB_STAGES];
    AscendC::LocalTensor<half> ubQuantF16List[UB_STAGES];
    // Staging buffer for this core's per-row dequant scales.
    AscendC::LocalTensor<float> ubPerTokenScaleOutput;


    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
    CopyUbToGmDequantScale copyUbToGmDequantScale;
};
|
||||
|
||||
} // namespace Catlass::Epilogue::Block
|
||||
|
||||
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
|
||||
@@ -0,0 +1,330 @@
|
||||
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
|
||||
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/epilogue/dispatch_policy.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/matrix_coord.hpp"
|
||||
#include "catlass/layout/layout.hpp"
|
||||
#include "catlass/detail/callback.hpp"
|
||||
|
||||
#include "hccl_shmem.hpp"
|
||||
#include "layout3d.hpp"
|
||||
|
||||
namespace Catlass::Epilogue::Block {
|
||||
template <
|
||||
uint32_t UB_STAGES_,
|
||||
class CType_,
|
||||
class LayoutPerTokenScale_,
|
||||
class DType_,
|
||||
class TileCopy_
|
||||
>
|
||||
class BlockEpilogue <
|
||||
EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>,
|
||||
CType_,
|
||||
Gemm::GemmType<float, LayoutPerTokenScale_>,
|
||||
DType_,
|
||||
TileCopy_
|
||||
> {
|
||||
public:
|
||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
||||
|
||||
// Data infos
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using ElementPerTokenScale = float;
|
||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||
using ElementD = typename DType_::Element;
|
||||
using LayoutD = typename DType_::Layout;
|
||||
|
||||
//using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::RowMajor>>;
|
||||
using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::VectorLayout>>;
|
||||
// Tile copy
|
||||
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
|
||||
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
|
||||
|
||||
struct Params {
|
||||
__gm__ int32_t *ptrTokenPerExpert{nullptr};
|
||||
int32_t EP;
|
||||
int32_t expertPerRank;
|
||||
int32_t n2;
|
||||
LayoutC layoutC;
|
||||
int32_t n0;
|
||||
int32_t rank;
|
||||
HcclShmem shmem;
|
||||
int32_t offsetD;
|
||||
|
||||
CATLASS_DEVICE
|
||||
Params() {};
|
||||
CATLASS_DEVICE
|
||||
Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_,
|
||||
LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) :
|
||||
ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_),
|
||||
expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_),
|
||||
shmem(shmem_), offsetD(offsetD_)
|
||||
{}
|
||||
};
|
||||
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
||||
{
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
|
||||
|
||||
//ub:192KB
|
||||
n0 = params.n0;
|
||||
size_t ubOffset = 0;
|
||||
for(int32_t i = 0; i < 2; i++) {
|
||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||
ubOffset += max_len * sizeof(ElementC);
|
||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
||||
ubOffset += max_len * sizeof(ElementD);
|
||||
ubFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += max_len * sizeof(float);
|
||||
scaleUbList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
||||
ubOffset += (max_len / n0) * sizeof(float);
|
||||
source_scale_offset[i] = -1;
|
||||
}
|
||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
|
||||
tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
|
||||
is_ping = true;
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Finalize()
|
||||
{
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
|
||||
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
~BlockEpilogue()
|
||||
{
|
||||
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
|
||||
GemmCoord& blockCoord,
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
auto &ubCFp32 = ubFp32List[is_ping];
|
||||
auto &scaleUb = scaleUbList[is_ping];
|
||||
// auto &ubOutFp32 = ubOutFp32List[is_ping];
|
||||
|
||||
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
|
||||
AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
|
||||
int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
|
||||
layout::VectorLayout scaleLauout{actualBlockShape.m()};
|
||||
if (source_scale_offset[event_id] != gmScaleOffset) {
|
||||
source_scale_offset[event_id] = gmScaleOffset;
|
||||
copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
|
||||
|
||||
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2); // 注意必须是MTE2_S,不能是MTE2_V,否则会读到0,造成乱码
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
|
||||
float scale = scaleUb(row);
|
||||
Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
|
||||
AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
|
||||
|
||||
int32_t lenTile = actualBlockShape.m();
|
||||
int32_t stTile = blockCoord.m();
|
||||
int32_t edTile = stTile + lenTile;
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
if (stRankInExpert >= edTile) {
|
||||
break;
|
||||
}
|
||||
else if (edRankInExpert <= stTile) {
|
||||
continue;
|
||||
}
|
||||
int32_t stData = max(stRankInExpert, stTile);
|
||||
int32_t edData = min(edRankInExpert, edTile);
|
||||
uint32_t lenData = edData - stData;
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
}
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
|
||||
copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
|
||||
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator() (
|
||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||
GemmCoord& blockCoord,
|
||||
GemmCoord& actualBlockShape,
|
||||
int32_t groupIdx,
|
||||
int32_t preSrcExpertSum,
|
||||
AscendC::GlobalTensor<int32_t> preSumBeforeRank,
|
||||
uint32_t *mPreSumBeforeRank
|
||||
){
|
||||
is_ping = !is_ping;
|
||||
auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
|
||||
|
||||
auto &ubC = ubCList[is_ping];
|
||||
auto &ubD = ubDList[is_ping];
|
||||
int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
|
||||
auto gmTileC = gmC[gmCOffset];
|
||||
|
||||
LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
|
||||
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
|
||||
copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
|
||||
int32_t lenTile = actualBlockShape.m();
|
||||
int32_t stTile = blockCoord.m();
|
||||
int32_t edTile = stTile + lenTile;
|
||||
int32_t preSumRankInExpert = 0;
|
||||
int32_t tileOffset = 0;
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
|
||||
for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
|
||||
int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
|
||||
int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
|
||||
int32_t stRankInExpert = preSumRankInExpert;
|
||||
int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
|
||||
preSumRankInExpert += lenRankInExpert;
|
||||
if (stRankInExpert >= edTile) {
|
||||
break;
|
||||
}
|
||||
else if (edRankInExpert <= stTile) {
|
||||
continue;
|
||||
}
|
||||
int32_t stData = max(stRankInExpert, stTile);
|
||||
int32_t edData = min(edRankInExpert, edTile);
|
||||
uint32_t lenData = edData - stData;
|
||||
if (lenData <= 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t dstOffsetInExpert = 0;
|
||||
if (stTile > stRankInExpert) {
|
||||
dstOffsetInExpert = stTile - stRankInExpert;
|
||||
}
|
||||
AscendC::GlobalTensor<ElementD> gmRemotePeer;
|
||||
__gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
|
||||
gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
|
||||
MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
|
||||
int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
|
||||
auto gmTileD = gmRemotePeer[gmDstOffset];
|
||||
LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
|
||||
LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
|
||||
copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
|
||||
tileOffset += lenData;
|
||||
}
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
Params params;
|
||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||
AscendC::LocalTensor<float> ubFp32List[UB_STAGES];
|
||||
AscendC::LocalTensor<float> scaleUbList[UB_STAGES];
|
||||
int32_t source_scale_offset[UB_STAGES];
|
||||
|
||||
int32_t max_len = 8 * 32 / 4 * 128;
|
||||
int32_t n0;
|
||||
bool is_ping = false;
|
||||
|
||||
|
||||
int32_t repeat = 128;
|
||||
|
||||
|
||||
CopyGmToUbC copyGmToUbC;
|
||||
CopyUbToGmD copyUbToGmD;
|
||||
|
||||
CopyScaleGmToUb copyScaleGmToUb;
|
||||
AscendC::GlobalTensor<int32_t> tokenPerExpert;
|
||||
Layout3D tokenPerExpertLayout;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,502 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
#include "catlass/coord.hpp"
|
||||
#include "catlass/gemm_coord.hpp"
|
||||
#include "catlass/gemm/dispatch_policy.hpp"
|
||||
#include "catlass/gemm/helper.hpp"
|
||||
#include "dispatch_policy_custom.hpp"
|
||||
|
||||
|
||||
namespace Catlass::Gemm::Block {
|
||||
|
||||
template<AscendC::HardEvent event>
|
||||
__aicore__ inline void SyncFlagFunc(int32_t eventID)
|
||||
{
|
||||
AscendC::SetFlag<event>(eventID);
|
||||
AscendC::WaitFlag<event>(eventID);
|
||||
}
|
||||
|
||||
template <
|
||||
uint32_t PRELOAD_STAGES_,
|
||||
uint32_t L1_STAGES_,
|
||||
uint32_t L0A_STAGES_,
|
||||
uint32_t L0B_STAGES_,
|
||||
uint32_t L0C_STAGES_,
|
||||
bool ENABLE_UNIT_FLAG_,
|
||||
bool ENABLE_SHUFFLE_K_,
|
||||
class L1TileShape_,
|
||||
class L0TileShape_,
|
||||
class AType_,
|
||||
class BType_,
|
||||
class CType_,
|
||||
class BiasType_,
|
||||
class TileCopy_,
|
||||
class TileMmad_
|
||||
>
|
||||
struct BlockMmad <
|
||||
MmadAtlasA2PreloadAsyncFixpipe<
|
||||
PRELOAD_STAGES_,
|
||||
L1_STAGES_,
|
||||
L0A_STAGES_,
|
||||
L0B_STAGES_,
|
||||
L0C_STAGES_,
|
||||
ENABLE_UNIT_FLAG_,
|
||||
ENABLE_SHUFFLE_K_
|
||||
>,
|
||||
L1TileShape_,
|
||||
L0TileShape_,
|
||||
AType_,
|
||||
BType_,
|
||||
CType_,
|
||||
BiasType_,
|
||||
TileCopy_,
|
||||
TileMmad_
|
||||
> {
|
||||
public:
|
||||
// Type Aliases
|
||||
using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe<
|
||||
PRELOAD_STAGES_,
|
||||
L1_STAGES_,
|
||||
L0A_STAGES_,
|
||||
L0B_STAGES_,
|
||||
L0C_STAGES_,
|
||||
ENABLE_UNIT_FLAG_,
|
||||
ENABLE_SHUFFLE_K_
|
||||
>;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using L1TileShape = L1TileShape_;
|
||||
using L0TileShape = L0TileShape_;
|
||||
using ElementA = typename AType_::Element;
|
||||
using LayoutA = typename AType_::Layout;
|
||||
using ElementB = typename BType_::Element;
|
||||
using LayoutB = typename BType_::Layout;
|
||||
using ElementC = typename CType_::Element;
|
||||
using LayoutC = typename CType_::Layout;
|
||||
using TileMmad = TileMmad_;
|
||||
using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
|
||||
using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
|
||||
using CopyGmToL1S = Gemm::Tile::CopyGmToL1<ArchTag, Gemm::GemmType<uint64_t, layout::VectorLayout>>;
|
||||
using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
|
||||
using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
|
||||
|
||||
using ElementAccumulator =
|
||||
typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
|
||||
using CopyL0CToGm = typename std::conditional<
|
||||
std::is_same_v<ElementA, int8_t>,
|
||||
Gemm::Tile::CopyL0CToGm<ArchTag, ElementAccumulator, CType_, Gemm::Tile::ScaleGranularity::PER_CHANNEL>,
|
||||
typename TileCopy_::CopyL0CToGm
|
||||
>::type;
|
||||
using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
|
||||
using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
|
||||
using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
|
||||
using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
|
||||
using LayoutCInL0 = layout::zN;
|
||||
|
||||
using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
|
||||
using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;
|
||||
|
||||
static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
|
||||
static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES;
|
||||
static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
|
||||
static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
|
||||
static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;
|
||||
|
||||
static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
|
||||
static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;
|
||||
|
||||
// L1 tile size
|
||||
static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
|
||||
static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
|
||||
static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t);
|
||||
// L0 tile size
|
||||
static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
|
||||
static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
|
||||
static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);
|
||||
|
||||
// Check LayoutC
|
||||
static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");
|
||||
|
||||
// Check L1TileShape
|
||||
static_assert(
|
||||
(std::is_same_v<ElementA, int8_t>
|
||||
? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE
|
||||
: (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE),
|
||||
"L1TileShape exceeding the L1 space for the given data type"
|
||||
);
|
||||
|
||||
// Check L0TileShape
|
||||
static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!");
|
||||
static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!");
|
||||
static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!");
|
||||
|
||||
static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
|
||||
"The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet");
|
||||
|
||||
static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(
|
||||
L1TileShape::M, L1TileShape::K);
|
||||
static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(
|
||||
L1TileShape::K, L1TileShape::N);
|
||||
|
||||
CATLASS_DEVICE
|
||||
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
|
||||
{
|
||||
syncGroupIdx = 0;
|
||||
InitL1(resource, l1BufAddrStart);
|
||||
InitL0A(resource);
|
||||
InitL0B(resource);
|
||||
InitL0C(resource);
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
~BlockMmad()
|
||||
{
|
||||
SynchronizeBlock();
|
||||
for (uint32_t i = 0; i < L1_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
|
||||
}
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(
|
||||
AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
|
||||
AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
|
||||
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
|
||||
AscendC::GlobalTensor<uint64_t> const &gmBlockS, layout::VectorLayout const &layoutScale,
|
||||
GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0
|
||||
)
|
||||
{
|
||||
uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
|
||||
|
||||
uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
|
||||
uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());
|
||||
|
||||
uint32_t startTileIdx = 0;
|
||||
if constexpr (ENABLE_SHUFFLE_K) {
|
||||
startTileIdx = AscendC::GetBlockIdx() % kTileCount;
|
||||
}
|
||||
|
||||
for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
|
||||
uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ?
|
||||
(startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount);
|
||||
|
||||
uint32_t kActual = (kTileIdx < kTileCount - 1) ?
|
||||
L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);
|
||||
|
||||
// Emission load instruction from GM to L1
|
||||
MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
|
||||
MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
|
||||
auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
|
||||
auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
|
||||
// Load first matrix A tile from GM to L1
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
|
||||
auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
|
||||
copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
|
||||
// Load first matrix B tile from GM to L1
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
|
||||
auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
|
||||
copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);
|
||||
|
||||
// If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile
|
||||
if (preloadCount == PRELOAD_STAGES) {
|
||||
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
|
||||
}
|
||||
|
||||
// Store the current load status
|
||||
uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ?
|
||||
(l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
|
||||
auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
|
||||
l1TileMmadParams.l1ListId = l1ListId;
|
||||
l1TileMmadParams.mRound = mRound;
|
||||
l1TileMmadParams.nRound = nRound;
|
||||
l1TileMmadParams.kActual = kActual;
|
||||
l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
|
||||
l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
|
||||
l1TileMmadParams.flag = flag;
|
||||
if (kLoopIdx == kTileCount - 1) {
|
||||
l1TileMmadParams.gmBlockC = gmBlockC;
|
||||
l1TileMmadParams.gmBlockS = gmBlockS;
|
||||
l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
|
||||
l1TileMmadParams.layoutScale = layoutScale;
|
||||
l1TileMmadParams.syncLoopIdx = syncLoopIdx;
|
||||
}
|
||||
|
||||
if (preloadCount < PRELOAD_STAGES) {
|
||||
++preloadCount;
|
||||
} else {
|
||||
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
|
||||
}
|
||||
l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void SynchronizeBlock()
|
||||
{
|
||||
while (preloadCount > 0) {
|
||||
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
|
||||
l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
|
||||
--preloadCount;
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void Finalize(int32_t target, int32_t flag = 0)
|
||||
{
|
||||
for(;syncGroupIdx <= target; syncGroupIdx++) {
|
||||
int32_t flagId = syncGroupIdx / 8 + flag;
|
||||
AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
|
||||
}
|
||||
}
|
||||
private:
|
||||
struct L1TileMmadParams {
|
||||
uint32_t l1ListId;
|
||||
uint32_t mRound;
|
||||
uint32_t nRound;
|
||||
uint32_t kActual;
|
||||
bool isKLoopFirst;
|
||||
bool isKLoopLast;
|
||||
AscendC::GlobalTensor<ElementC> gmBlockC;
|
||||
AscendC::GlobalTensor<uint64_t> gmBlockS;
|
||||
LayoutC layoutCInGm;
|
||||
layout::VectorLayout layoutScale;
|
||||
int32_t syncLoopIdx;
|
||||
int32_t flag;
|
||||
|
||||
CATLASS_DEVICE
|
||||
L1TileMmadParams() = default;
|
||||
};
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
|
||||
{
|
||||
uint32_t l1AOffset = l1BufAddrStart;
|
||||
uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES;
|
||||
|
||||
for (uint32_t i = 0; i < L1_STAGES; ++i) {
|
||||
l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
|
||||
l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
|
||||
l1AEventList[i] = i;
|
||||
l1BEventList[i] = i + L1_STAGES;
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
|
||||
}
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
|
||||
l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0A(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0A_STAGES; ++i) {
|
||||
l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
|
||||
l0AEventList[i] = i;
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0B(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0B_STAGES; ++i) {
|
||||
l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
|
||||
l0BEventList[i] = i + L0A_STAGES;
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void InitL0C(Arch::Resource<ArchTag> &resource)
|
||||
{
|
||||
for (uint32_t i = 0; i < L0C_STAGES; ++i) {
|
||||
l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
|
||||
l0CEventList[i] = i;
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
|
||||
}
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void L1TileMmad(L1TileMmadParams const ¶ms)
|
||||
{
|
||||
uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
|
||||
uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
|
||||
uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
|
||||
auto &l1ATensor = l1ATensorList[params.l1ListId];
|
||||
auto &l1BTensor = l1BTensorList[params.l1ListId];
|
||||
|
||||
auto &l0CTensor = l0CTensorList[l0CListId];
|
||||
LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));
|
||||
|
||||
if constexpr (!ENABLE_UNIT_FLAG) {
|
||||
if (params.isKLoopFirst) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
|
||||
uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ?
|
||||
L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);
|
||||
|
||||
for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
|
||||
uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ?
|
||||
L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);
|
||||
|
||||
auto &l0ATile = l0ATensorList[l0AListId];
|
||||
auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
|
||||
auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
|
||||
auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
|
||||
if ((mPartIdx == 0) && (kPartIdx == 0)) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1ListId]);
|
||||
}
|
||||
copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
|
||||
if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1ListId]);
|
||||
}
|
||||
|
||||
for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
|
||||
uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ?
|
||||
L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);
|
||||
|
||||
auto &l0BTile = l0BTensorList[l0BListId];
|
||||
auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
|
||||
auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
|
||||
auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
|
||||
if ((kPartIdx == 0) && (nPartIdx == 0)) {
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1ListId]);
|
||||
}
|
||||
copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
|
||||
if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1ListId]);
|
||||
}
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
|
||||
|
||||
auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
|
||||
auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];
|
||||
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
|
||||
// If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
|
||||
bool initC = (params.isKLoopFirst && (kPartIdx == 0));
|
||||
// If the unit flag is enabled, the unit flag is set according to the calculation progress
|
||||
uint8_t unitFlag = 0b00;
|
||||
if constexpr (ENABLE_UNIT_FLAG) {
|
||||
if (params.isKLoopLast &&
|
||||
(mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
|
||||
unitFlag = 0b11;
|
||||
} else {
|
||||
unitFlag = 0b10;
|
||||
}
|
||||
}
|
||||
tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
|
||||
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
|
||||
l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
|
||||
l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.isKLoopLast) {
|
||||
auto layoutCInGm = params.layoutCInGm;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
auto layoutScale = params.layoutScale;
|
||||
auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1)));
|
||||
AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS);
|
||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_FIX>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_FIX>(0);
|
||||
}
|
||||
if constexpr (!ENABLE_UNIT_FLAG) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0);
|
||||
} else {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
|
||||
}
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
|
||||
} else {
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11);
|
||||
} else {
|
||||
copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
|
||||
}
|
||||
}
|
||||
l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;
|
||||
if constexpr (std::is_same_v<ElementA, int8_t>) {
|
||||
AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
|
||||
}
|
||||
Finalize(params.syncLoopIdx, params.flag);
|
||||
}
|
||||
}
|
||||
AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
|
||||
AscendC::LocalTensor<uint64_t> l1STensor;
|
||||
int32_t syncGroupIdx;
|
||||
int32_t l1AEventList[L1_STAGES];
|
||||
int32_t l1BEventList[L1_STAGES];
|
||||
uint32_t l1ListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
|
||||
int32_t l0AEventList[L0A_STAGES];
|
||||
uint32_t l0AListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
|
||||
int32_t l0BEventList[L0B_STAGES];
|
||||
uint32_t l0BListId{0};
|
||||
|
||||
AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
|
||||
int32_t l0CEventList[L0C_STAGES_];
|
||||
uint32_t l0CListId{0};
|
||||
|
||||
L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
|
||||
uint32_t l1TileMmadParamsId{0};
|
||||
uint32_t preloadCount{0};
|
||||
|
||||
TileMmad tileMmad;
|
||||
CopyGmToL1A copyGmToL1A;
|
||||
CopyGmToL1B copyGmToL1B;
|
||||
CopyGmToL1S copyGmToL1S;
|
||||
CopyL1ToL0A copyL1ToL0A;
|
||||
CopyL1ToL0B copyL1ToL0B;
|
||||
CopyL0CToGm copyL0CToGm;
|
||||
};
|
||||
|
||||
} // namespace Catlass::Gemm::Block
|
||||
|
||||
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
|
||||
@@ -0,0 +1,9 @@
|
||||
|
||||
#ifndef CONST_ARGS_HPP
|
||||
#define CONST_ARGS_HPP
|
||||
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
constexpr static int32_t NUMS_PER_FLAG = 16;
|
||||
constexpr static int32_t CACHE_LINE = 512;
|
||||
constexpr static int32_t RESET_VAL = 0xffff;
|
||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
||||
#endif
|
||||
@@ -0,0 +1,40 @@
|
||||
#ifndef COPY_GM_TO_L1_CUSTOM_HPP
|
||||
#define COPY_GM_TO_L1_CUSTOM_HPP
|
||||
|
||||
namespace Catlass::Gemm::Tile {
|
||||
/// Partial specialization for nZ in and nZ out.
|
||||
template <
|
||||
class ArchTag,
|
||||
class Element
|
||||
>
|
||||
struct CopyGmToL1<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
|
||||
using LayoutDst = layout::VectorLayout;
|
||||
using LayoutSrc = layout::VectorLayout;
|
||||
|
||||
static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); // int64, 32/8=4
|
||||
|
||||
// Mehtods
|
||||
|
||||
CATLASS_DEVICE
|
||||
CopyGmToL1() {};
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(
|
||||
AscendC::LocalTensor<Element> const &dstTensor,
|
||||
AscendC::GlobalTensor<Element> const &srcTensor,
|
||||
LayoutDst const &layoutDst, LayoutSrc const &layoutSrc)
|
||||
{
|
||||
uint32_t blockCount = 1;
|
||||
uint32_t blockLen = CeilDiv<ELE_NUM_PER_C0>(layoutSrc.shape(0));
|
||||
|
||||
AscendC::DataCopyParams repeatParams;
|
||||
|
||||
repeatParams.blockCount = blockCount;
|
||||
repeatParams.blockLen = blockLen;
|
||||
repeatParams.srcStride = 0;
|
||||
repeatParams.dstStride = 0;
|
||||
AscendC::DataCopy(dstTensor, srcTensor, repeatParams);
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif // COPY_GM_TO_L1_CUSTOM_HPP
|
||||
@@ -0,0 +1,47 @@
|
||||
#ifndef COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
#define COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
|
||||
namespace Catlass::Gemm::Tile {
|
||||
template <
|
||||
class ElementAccumulator_,
|
||||
class ElementDst_,
|
||||
bool ReluEnable_
|
||||
>
|
||||
struct CopyL0CToGm<Catlass::Arch::AtlasA2,
|
||||
ElementAccumulator_,
|
||||
Gemm::GemmType<ElementDst_, layout::RowMajor>,
|
||||
ScaleGranularity::PER_CHANNEL,
|
||||
ReluEnable_>
|
||||
{
|
||||
using ArchTag = Catlass::Arch::AtlasA2;
|
||||
using ElementDst = ElementDst_;
|
||||
using ElementSrc = ElementAccumulator_;
|
||||
using LayoutSrc = Catlass::layout::zN;
|
||||
using LayoutDst = Catlass::layout::RowMajor;
|
||||
static constexpr auto quantPre = CopyL0CToGmQuantMode<ArchTag, ElementSrc, ElementDst,
|
||||
ScaleGranularity::PER_CHANNEL>::VALUE;
|
||||
static constexpr auto reluEn = ReluEnable_;
|
||||
|
||||
CATLASS_DEVICE
|
||||
void operator()(AscendC::GlobalTensor<ElementDst> const &dst, AscendC::LocalTensor<ElementSrc> const &src, AscendC::LocalTensor<uint64_t> cbufWorkspace,
|
||||
LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::FixpipeParamsV220 intriParams;
|
||||
|
||||
// Fixpipe layout information
|
||||
intriParams.nSize = dstLayout.shape(1);
|
||||
intriParams.mSize = dstLayout.shape(0);
|
||||
intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0);
|
||||
intriParams.dstStride = dstLayout.stride(0);
|
||||
|
||||
// Fixpipe auxiliary arguments
|
||||
intriParams.quantPre = quantPre;
|
||||
intriParams.reluEn = reluEn;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
|
||||
// Call AscendC Fixpipe
|
||||
AscendC::Fixpipe<ElementDst, ElementSrc, AscendC::CFG_ROW_MAJOR>(dst, src, cbufWorkspace, intriParams);
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif // COPY_L0C_TO_GM_CUSTOM_HPP
|
||||
@@ -0,0 +1,53 @@
|
||||
#ifndef DISPATH_POLICY_CUSTOM_HPP
|
||||
#define DISPATH_POLICY_CUSTOM_HPP
|
||||
|
||||
namespace Catlass::Gemm {
|
||||
template <bool ENABLE_UNIT_FLAG_ = false, bool ENABLE_SHUFFLE_K_ = false>
|
||||
struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 {
|
||||
static constexpr uint32_t STAGES = 2;
|
||||
static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;
|
||||
static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;
|
||||
};
|
||||
|
||||
// Mainloop dispatch-policy tag combining the preloaded asynchronous MMAD
// pipeline with a fixpipe-based epilogue. Adds nothing of its own: all
// configuration (stage depths, flags) is inherited from
// MmadAtlasA2PreloadAsync; the distinct type only steers template dispatch.
template <uint32_t PRELOAD_STAGES_, uint32_t L1_STAGES_, uint32_t L0A_STAGES_, uint32_t L0B_STAGES_,
    uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncFixpipe :
    public MmadAtlasA2PreloadAsync<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    > {
};
|
||||
}
|
||||
|
||||
namespace Catlass::Epilogue {
|
||||
|
||||
// Epilogue dispatch-policy tag (AtlasA2): copy-out without dequantization.
// UB_STAGES_ sets the unified-buffer pipelining depth used by the matching
// BlockEpilogue specialization.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2UnQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
|
||||
|
||||
// Epilogue dispatch-policy tag (AtlasA2). Name suggests per-token dequant
// followed by requantization of the output — confirm against the matching
// BlockEpilogue specialization. UB_STAGES_ sets the unified-buffer
// pipelining depth.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
|
||||
|
||||
// Epilogue dispatch-policy tag (AtlasA2). Name suggests per-token dequant,
// SwiGLU activation, then requantization — confirm against the matching
// BlockEpilogue specialization. UB_STAGES_ sets the unified-buffer
// pipelining depth.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
|
||||
|
||||
// Epilogue dispatch-policy tag (AtlasA2): second variant of the per-token
// dequant epilogue (see the corresponding BlockEpilogue specialization for
// what distinguishes it from V1). UB_STAGES_ sets the unified-buffer
// pipelining depth.
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantV2 {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
|
||||
}
|
||||
#endif // DISPATH_POLICY_CUSTOM_HPP
|
||||
@@ -0,0 +1,16 @@
|
||||
#ifndef GET_TENSOR_ADDR_HPP
|
||||
#define GET_TENSOR_ADDR_HPP
|
||||
#include "kernel_operator.h"
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) {
|
||||
__gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
|
||||
uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address.
|
||||
// Moving 3 bits to the right means dividing by sizeof(uint64 t).
|
||||
__gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
|
||||
return reinterpret_cast<__gm__ T*>(*(retPtr + index));
|
||||
}
|
||||
|
||||
#endif // GET_TENSOR_ADDR_HPP
|
||||
195
csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp
Normal file
195
csrc/dispatch_ffn_combine_bf16/op_kernel/utils/hccl_shmem.hpp
Normal file
@@ -0,0 +1,195 @@
|
||||
#ifndef SYNC_UTIL_HPP
|
||||
#define SYNC_UTIL_HPP
|
||||
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "const_args.hpp"
|
||||
|
||||
#include "moe_distribute_base.h"
|
||||
|
||||
#ifndef HCCL_COMM
|
||||
#include "shmem_api.h"
|
||||
using namespace AscendC::HcclContextDef;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
|
||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
|
||||
|
||||
// Barrier across AI cores: raise cross-core flag 8 on the FIX pipe, then
// block until the flag has been set by all participating cores.
FORCE_INLINE_AICORE void AicSyncAll() {
    AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
    AscendC::CrossCoreWaitFlag<0x0>(8);
}
|
||||
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
|
||||
*((__gm__ T *)addr) = val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
|
||||
return *((__gm__ T *)cache);
|
||||
}
|
||||
|
||||
// Cleans and invalidates the single cache line containing addr, so the next
// scalar read of that line observes remote GM writes and any pending local
// write becomes globally visible.
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
    using namespace AscendC;
    GlobalTensor<uint8_t> global;
    global.SetGlobalBuffer(addr);

    // Important: add hint to avoid dcci being optimized by compiler
    // (the empty volatile asm acts as a compiler barrier on both sides).
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<uint8_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(global);
    __asm__ __volatile__("");
}
|
||||
|
||||
// Spin-waits on a GM signal word until it equals cmp_val, or cmp_val + 1
// (meaning a peer has already advanced to the next barrier round), and
// returns the observed value. The cache line is invalidated before every
// poll so remote writes become visible.
//
// Fixes over the original: the signal was re-read from GM between the
// comparison and the return, so the returned value could differ from the
// one compared if a peer advanced in between; it is now read exactly once
// per iteration. The unreachable `return -1` after the infinite loop is
// also removed.
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    while (true) {
        gm_dcci((__gm__ uint8_t *)sig_addr);
        const int32_t observed = *sig_addr;
        if (observed == cmp_val || observed == cmp_val + 1) {
            return observed;
        }
    }
}
|
||||
|
||||
// Spin-waits until the GM signal word differs from cmp_val. Each poll pulls
// the signal into the unified buffer via an MTE2 DataCopy (bypassing the
// scalar cache) instead of using gm_dcci.
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    do {
        // Hand-built UB tensor at buffer offset 0 used as the landing zone
        // for the copy (no queue/TPipe allocation involved).
        AscendC::LocalTensor<int32_t> ub;
        ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
        ub.address_.bufferAddr = 0;
        AscendC::GlobalTensor<int32_t> sig;
        sig.SetGlobalBuffer(sig_addr);
        // Copy 8 x int32 (32 bytes, the minimum DataCopy granule) from GM.
        AscendC::DataCopy(ub, sig, 8);
        // Make the scalar unit wait for the MTE2 copy before reading ub(0).
        AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
        if (ub(0) != cmp_val) {
            return;
        }
    } while (true);
    return;
}
|
||||
|
||||
|
||||
class HcclShmem {
|
||||
public:
|
||||
#ifdef HCCL_COMM
|
||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
||||
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
|
||||
|
||||
m_rank = WinContext_->localUsrRankId;
|
||||
m_rankSize = WinContext_->rankSize;
|
||||
m_segmentSize = WinContext_->winSize;
|
||||
for (int i = 0; i < m_rankSize; i++) {
|
||||
m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
|
||||
((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
|
||||
}
|
||||
}
|
||||
#else
|
||||
FORCE_INLINE_AICORE
|
||||
HcclShmem(){
|
||||
m_segmentSize = SHMEM_MEM;
|
||||
}
|
||||
FORCE_INLINE_AICORE
|
||||
void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
|
||||
symmetricPtr = symmetricPtr_;
|
||||
m_rank = rank;
|
||||
m_rankSize = rankSize;
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() () const {
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[m_rank];
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator() (int32_t index) const {
|
||||
#ifdef HCCL_COMM
|
||||
return m_ptrArray[index];
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
|
||||
#endif
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
GM_ADDR operator () (int64_t offset, int32_t rankId) const {
|
||||
#ifdef HCCL_COMM
|
||||
if (offset < 0 || offset >= m_segmentSize) {
|
||||
return nullptr;
|
||||
}
|
||||
if (rankId < 0 || rankId >= m_rankSize) {
|
||||
return nullptr;
|
||||
}
|
||||
return m_ptrArray[rankId] + offset;
|
||||
#else
|
||||
return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
size_t SegmentSize() const {
|
||||
return m_segmentSize;
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
int32_t RankSize() const {
|
||||
return m_rankSize;
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
~HcclShmem() {
|
||||
}
|
||||
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
void CrossRankSync() {
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
||||
__gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset;
|
||||
__gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048;
|
||||
int count = gm_load(sync_base) + 1;
|
||||
int vec_id = AscendC::GetBlockIdx();
|
||||
int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation();
|
||||
for(int i = vec_id; i < m_rankSize; i += vec_size) {
|
||||
__gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16;
|
||||
gm_store(sync_remote, count);
|
||||
gm_dcci((__gm__ uint8_t*)sync_remote);
|
||||
auto sync_check = sync_counter + i * 16;
|
||||
gm_signal_wait_until_eq_for_barrier(sync_check, count);
|
||||
}
|
||||
|
||||
AscendC::SyncAll<true>();
|
||||
gm_store(sync_base, count);
|
||||
}
|
||||
|
||||
FORCE_INLINE_AICORE
|
||||
__gm__ int32_t* SyncBaseAddr() {
|
||||
uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
|
||||
return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
|
||||
}
|
||||
|
||||
private:
|
||||
GM_ADDR symmetricPtr;
|
||||
int32_t m_rank;
|
||||
int32_t m_rankSize;
|
||||
size_t m_segmentSize;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
20
csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp
Normal file
20
csrc/dispatch_ffn_combine_bf16/op_kernel/utils/layout3d.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef LAYOUT_3D_HPP
|
||||
#define LAYOUT_3D_HPP
|
||||
#include "kernel_operator.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
class Layout3D {
|
||||
int64_t strides[2];
|
||||
public:
|
||||
CATLASS_DEVICE
|
||||
Layout3D() {}
|
||||
CATLASS_DEVICE
|
||||
Layout3D(int64_t stride0, int64_t stride1) {
|
||||
strides[0] = stride0;
|
||||
strides[1] = stride1;
|
||||
}
|
||||
CATLASS_DEVICE
|
||||
int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) {
|
||||
return dim0 * strides[0] + dim1 * strides[1] + dim2;
|
||||
}
|
||||
};
|
||||
#endif // LAYOUT_3D_HPP
|
||||
@@ -0,0 +1,25 @@
|
||||
#ifndef SELECT_HELPER_HPP
|
||||
#define SELECT_HELPER_HPP
|
||||
|
||||
#include "catlass/layout/layout.hpp"
|
||||
using namespace AscendC;
|
||||
using namespace Catlass;
|
||||
|
||||
// Builds a matmul B-operand layout from its (k, n) extents. Primary
// template: layouts constructible directly from the two extents
// (e.g. row-/column-major). ElementType is unused here; it matters only for
// the zN specialization below.
template <typename Layout, typename ElementType, typename = void>
struct LayoutBInitializer {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout{k, n};
    }
};
|
||||
|
||||
// Specialization for the zN fractal layout, which must be built through its
// MakeLayout factory because the fractal geometry depends on the element
// type.
template <typename Layout, typename ElementType>
struct LayoutBInitializer<Layout, ElementType,
    std::enable_if_t<std::is_same_v<Layout, layout::zN>>
> {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout::template MakeLayout<ElementType>(k, n);
    }
};
|
||||
#endif // SELECT_HELPER_HPP
|
||||
Reference in New Issue
Block a user