add dispath_ffn_combine_bf16 (#5866)

### What this PR does / why we need it?
add dispath_ffn_combine_bf16

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
guanguan0308
2026-01-21 09:30:30 +08:00
committed by GitHub
parent bec8641876
commit 1ed9524763
45 changed files with 8420 additions and 1 deletions

View File

@@ -0,0 +1,208 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/epilogue/block/block_epilogue.hpp"
namespace Catlass::Epilogue::Block {
// float scale, dequant per expert
// Per-token dequantization epilogue:
//   D[row] = cast(cast(C[row], fp32) * perTokenScale[row])
// The C tile is processed one row at a time, pipelining the GM->UB load,
// vector compute and UB->GM store across UB_STAGES double-buffered UB slots
// via hardware-event flags.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequant<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequant<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;

    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;

    // Check data infos
    static_assert(
        (std::is_same_v<ElementC, half> || std::is_same_v<ElementC, bfloat16_t>) &&
        (std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
        "The element type template parameters of BlockEpilogue are wrong"
    );
    static_assert(
        std::is_same_v<LayoutC, layout::RowMajor> &&
        std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
        "The layout template parameters of BlockEpilogue are wrong"
    );

    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;

    // Expert-parallel metadata.
    // NOTE(review): none of these fields are read by this epilogue itself;
    // presumably kept for interface compatibility with callers -- confirm
    // before removing.
    struct Params {
        __gm__ int32_t *ptrTokenPerExpert{nullptr};
        int32_t EP;
        int32_t expertPerRank;

        CATLASS_DEVICE
        Params() {}
        CATLASS_DEVICE
        Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_)
            : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
    };

    // Carves per-stage UB buffers out of the shared UB pool and pre-sets the
    // "buffer free" flags so the first WaitFlag in operator() does not block.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        // The first 4096 bytes of UB are skipped; presumably reserved by the
        // surrounding kernel -- TODO confirm.
        size_t ubOffset = 4096;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        // Per-stage UB row capacity in elements.
        // NOTE(review): operator() uses the runtime shapeC.column() as the row
        // length; rows wider than 12000 elements would overrun these buffers --
        // confirm callers guarantee this bound.
        constexpr int32_t blockN = 12000;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += blockN * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += blockN * sizeof(ElementD);
            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
            ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += blockN * sizeof(float);
        }
    }

    // Consumes the flags that the pipeline leaves set so that no event IDs
    // remain pending when the epilogue is torn down.
    CATLASS_DEVICE
    void Finalize()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }

    CATLASS_DEVICE
    ~BlockEpilogue()
    {
    }

    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }

    // Dequantizes shapeC.row() rows: each row of C is cast to fp32, scaled by
    // its per-token scale, then cast back to the D element type and stored.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
        AscendC::GlobalTensor<ElementD> const &gmD
    )
    {
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();
        uint32_t tileLoops = blockM;
        for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx++) {
            auto gmTileC = gmC[loopIdx * blockN];
            auto &ubC = ubCList[ubListId];
            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubD = ubDList[ubListId];
            auto gmTileD = gmD[loopIdx * blockN];
            LayoutC layoutUbC{1, blockN};
            // Move C from GM workspace to UB
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            // Cast C to FP32 in UB
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            // Get per-token scale from row loopIdx of gmPerTokenScale
            ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);
            // S->V sync: the scalar read above must complete before the vector
            // Muls consumes it.
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Multiply FP32 C by the per-token scale
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
            AscendC::PipeBarrier<PIPE_V>();
            // Cast the muls result back to fp16/bf16
            LayoutD layoutUbD{1, blockN};
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            // Advance to the next double-buffer slot.
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }

private:
    Params params;
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
    int32_t eventUbCVMTE2List[UB_STAGES];  // "ubC drained, free for next load"
    int32_t eventUbCMTE2VList[UB_STAGES];  // "ubC loaded, ready for vector ops"
    int32_t eventUbDMTE3VList[UB_STAGES];  // "ubD stored, free for next result"
    int32_t eventUbDVMTE3List[UB_STAGES];  // "ubD computed, ready to store"
    uint32_t ubListId{0};                  // current double-buffer slot
    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];
    // Removed: ubMulList was declared but never bound to a UB buffer and only
    // ever aliased into an unused local; it was dead state.
    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP

View File

@@ -0,0 +1,402 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
namespace Catlass::Epilogue::Block {
// float scale, dequant per expert
// Fused epilogue: per-token dequant -> SwiGLU -> (first overload only) dynamic
// per-token int8 quantization with the dequant scale written back to GM.
// C holds [tokens, blockN]; columns [0, blockN/2) are the gate half and
// [blockN/2, blockN) the up half, so D has blockN/2 columns per token.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileElemWiseMuls_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileElemWiseMuls_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;
    // Check data infos
    static_assert(
        (std::is_same_v<ElementC, half> || std::is_same_v<ElementC, bfloat16_t>) &&
        (std::is_same_v<ElementD, float> || std::is_same_v<ElementD, int8_t> || std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
        "The element type template parameters of BlockEpilogue are wrong"
    );
    static_assert(
        std::is_same_v<LayoutC, layout::RowMajor> &&
        std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> && std::is_same_v<LayoutD, layout::RowMajor>,
        "The layout template parameters of BlockEpilogue are wrong"
    );
    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
    using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm<ArchTag, Gemm::GemmType<ElementPerTokenScale, LayoutPerTokenScale>>;
    // NOTE(review): these fields are not read by the operator() overloads below
    // (tensors and layouts arrive as call arguments); presumably kept for
    // interface compatibility -- confirm before removing.
    struct Params {
        __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
        LayoutPerTokenScale layoutPerTokenScale{};
        __gm__ ElementD *ptrD{nullptr};
        LayoutD layoutD{};
        CATLASS_DEVICE
        Params() {};
        CATLASS_DEVICE
        Params(__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
            __gm__ ElementD *ptrD_, LayoutD const &layoutD_
        ) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_),
            ptrD(ptrD_), layoutD(layoutD_) {}
    };
    // Carves the per-stage UB working buffers out of the shared UB pool and
    // pre-sets the "buffer free" flags so the first waits in operator() pass.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        size_t ubOffset = 0;
        int32_t eventVMTE2 = 0;
        int32_t eventMTE2V = 0;
        int32_t eventMTE3V = 0;
        int32_t eventVMTE3 = 0;
        // Per-stage UB row capacity in elements.
        // NOTE(review): operator() uses the runtime shapeC.column() as the row
        // length; rows wider than 4096 elements would overrun these buffers --
        // confirm callers guarantee blockN <= 4096.
        constexpr uint32_t blockN = 4096;
        constexpr uint32_t ChunkTileLen = blockN / 2;
        constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += blockN * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += blockN * sizeof(ElementD);
            ubCFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += blockN * sizeof(float);
            ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += ChunkTileLen * sizeof(float);
            ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += ChunkTileLen * sizeof(float);
            ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += HalfChunkTileLen * sizeof(float);
            // The quant staging tensors alias the abs buffer (reinterpreted),
            // so they consume no extra UB.
            ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<int32_t>();
            ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast<half>();
            eventUbCVMTE2List[i] = eventVMTE2++;
            eventUbCMTE2VList[i] = eventMTE2V++;
            eventUbDMTE3VList[i] = eventMTE3V++;
            eventUbDVMTE3List[i] = eventVMTE3++;
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
        // The remaining UB stages the per-token dequant scales produced by
        // this core before the bulk copy-out at the end of operator().
        ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
    }
    // Drains the flags left set by the pipeline so no event IDs stay pending.
    CATLASS_DEVICE
    void Finalize()
    {
        for (uint32_t i = 0; i < UB_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
        }
    }
    CATLASS_DEVICE
    ~BlockEpilogue()
    {
    }
    CATLASS_DEVICE
    void UpdateParams(Params const &params_)
    {
        params = params_;
    }
    // Each tile is one token row (e.g. 1 x 7168); each block is all tokens of
    // one expert, i.e. [group[i], 7168].
    // Dequant + SwiGLU + dynamic per-token int8 quant:
    //   fp32      = cast(C) * gmPerTokenScale1[row]
    //   silu      = fp32[0:N/2] / (1 + exp(-fp32[0:N/2])) * fp32[N/2:N]
    //   scale2    = max(|silu|) / 127       (written to gmPerTokenScale2)
    //   D[row]    = round(silu * 127 / max(|silu|))
    // Rows are distributed over the last `epilogueCoreNum` vector sub-blocks.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale1,
        AscendC::GlobalTensor<ElementD> const &gmD,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale2,
        uint32_t epilogueCoreNum = 40,
        Callback &&callback = Callback{}
    )
    {
        callback();
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();
        uint32_t tileLoops = blockM;
        // Flatten (block idx, sub-block id) into a global vector-core index.
        uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();
        uint32_t subblockNum = get_block_num() * 2;
        // The first subblockNum - epilogueCoreNum cores are reserved for data
        // movement; only the remaining cores run this epilogue.
        uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;
        if (subblockIdx < moveDataCoreNum) {
            return;
        }
        uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;
        // Distribute blockM rows across the epilogue cores; the first
        // `remainderData` cores each take one extra row.
        uint32_t perCoreData = blockM / epilogueCoreNum;
        uint32_t remainderData = blockM % epilogueCoreNum;
        uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
        uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);
        uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1); // NOTE(review): unused
        uint32_t ChunkTileLen = blockN / 2;
        uint32_t HalfChunkTileLen = ChunkTileLen / 2; // NOTE(review): unused
        for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {
            auto gmTileC = gmC[loopIdx * blockN];
            auto &ubC = ubCList[ubListId];
            auto &ubD = ubDList[ubListId];
            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];
            auto &ubAbs = ubCFp32ChunkNAbsList[ubListId];
            // auto &ubMax = ubCFp32ChunkNMaxList[ubListId];
            auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId];
            // ubOutputTmp / sharedUbTmpBuffer alias existing buffers to save UB.
            auto &ubOutputTmp = ubAbs;
            auto &sharedUbTmpBuffer = ubReduceMax;
            auto &ubQuantS32 = ubQuantS32List[ubListId];
            auto &ubQuantF16 = ubQuantF16List[ubListId];
            auto gmTileD = gmD[loopIdx * ChunkTileLen];
            LayoutC layoutUbC{1, blockN};
            // Move C from the GM workspace into UB.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            // Cast C to FP32 in UB.
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            // Fetch the per-token scale: row loopIdx of gmPerTokenScale1.
            ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);
            // S->V sync: the scalar read must land before the vector Muls uses it.
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            // Multiply the FP32 C by the per-token scale.
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
            AscendC::PipeBarrier<PIPE_V>();
            // SwiGLU: chunk = gate / (1 + exp(-gate)) * up, where
            // gate = fp32[0:ChunkTileLen] and up = fp32[ChunkTileLen:blockN].
            AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            // TODO: check whether the division disturbs data beyond ChunkTileLen.
            AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
            // Quantization step (two candidate implementations differ here).
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::ReduceMax<float>(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
            AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
            // TODO: compare the efficiency of the two computation methods.
            ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
            AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
            // Stage this row's dequant scale (max/127) for the copy-out below.
            auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx;
            ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f);
            AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
            AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            // int32 -> half conversion consumes the deq scale; use identity.
            AscendC::SetDeqScale(static_cast<half>(1.0));
            AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            // NOTE(review): the MTE3_V wait below passes eventUbDVMTE3List while
            // the matching SetFlag uses eventUbDMTE3VList (the sibling per-token
            // epilogue uses the MTE3V list for both). Both lists hold identical
            // IDs, so runtime behavior matches, but the naming is inconsistent
            // -- confirm and unify.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
            AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen);
            // AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);
            LayoutD layoutUbD{1, ChunkTileLen};
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            // Advance to the next double-buffer slot.
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
        // Flush this core's staged dequant scales to GM in one shot.
        if(tasksForIdx > 0){
            LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx};
            AscendC::SetFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
            AscendC::WaitFlag<AscendC::HardEvent::S_MTE3>(EVENT_ID0);
            copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2);
        }
    }
    // Variant without quantization: dequant is skipped entirely (no per-token
    // scale input) and the SwiGLU result is cast straight to ElementD.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        MatrixCoord const &shapeC,
        AscendC::GlobalTensor<ElementD> const &gmD,
        uint32_t epilogueCoreNum = 40,
        Callback &&callback = Callback{}
    )
    {
        callback();
        uint32_t blockM = shapeC.row();
        uint32_t blockN = shapeC.column();
        uint32_t tileLoops = blockM;
        uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num();
        //uint32_t subblockIdx = get_block_idx() * 2 + get_subblockid();
        uint32_t subblockNum = get_block_num() * 2;
        uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum;
        if (subblockIdx < moveDataCoreNum) {
            return;
        }
        uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum;
        uint32_t perCoreData = blockM / epilogueCoreNum;
        uint32_t remainderData = blockM % epilogueCoreNum;
        uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData;
        uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? epilogueCoreIdx : remainderData);
        uint32_t alignedPerCoreData = RoundUp<BYTE_PER_BLK / sizeof(ElementPerTokenScale)>(perCoreData + 1); // NOTE(review): unused
        uint32_t ChunkTileLen = blockN / 2;
        uint32_t HalfChunkTileLen = ChunkTileLen / 2; // NOTE(review): unused
        for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) {
            auto gmTileC = gmC[loopIdx * blockN];
            auto &ubC = ubCList[ubListId];
            auto &ubD = ubDList[ubListId];
            auto &ubCFp32 = ubCFp32List[ubListId];
            auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId];
            auto gmTileD = gmD[loopIdx * ChunkTileLen];
            LayoutC layoutUbC{1, blockN};
            // Load one token row of C into UB and cast it to FP32.
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
            AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
            // SwiGLU: chunk = gate / (1 + exp(-gate)) * up (same as above).
            AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
            AscendC::PipeBarrier<PIPE_V>();
            // NOTE(review): same MTE3V/VMTE3 event-list naming swap as in the
            // quantizing overload -- identical IDs, so behavior matches.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDVMTE3List[ubListId]);
            AscendC::Cast(ubD, ubCFp32ChunkN, AscendC::RoundMode::CAST_ROUND, ChunkTileLen);
            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDMTE3VList[ubListId]);
            LayoutD layoutUbD{1, ChunkTileLen};
            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
            // copyUbToGmD(gmTileD, ubCFp32ChunkN, layoutUbD, layoutUbD);
            copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD);
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
            ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
        }
    }
private:
    Params params;
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];     // raw C row (fp16/bf16)
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];     // output row staging
    int32_t eventUbCVMTE2List[UB_STAGES];  // "ubC drained, free for next load"
    int32_t eventUbCMTE2VList[UB_STAGES];  // "ubC loaded, ready for vector ops"
    int32_t eventUbDMTE3VList[UB_STAGES];  // "ubD stored, free for next result"
    int32_t eventUbDVMTE3List[UB_STAGES];  // "ubD computed, ready to store"
    uint32_t ubListId{0};                  // current double-buffer slot
    AscendC::LocalTensor<float> ubCFp32List[UB_STAGES];        // fp32 copy of the full row
    AscendC::LocalTensor<float> ubCFp32ChunkNList[UB_STAGES];  // SwiGLU result (half row)
    AscendC::LocalTensor<float> ubCFp32ChunkNAbsList[UB_STAGES];  // |silu|; also quant scratch
    AscendC::LocalTensor<float> ubCFp32ChunkNMaxList[UB_STAGES];  // ReduceMax result/scratch
    AscendC::LocalTensor<int32_t> ubQuantS32List[UB_STAGES];   // aliases the abs buffer
    AscendC::LocalTensor<half> ubQuantF16List[UB_STAGES];      // aliases the abs buffer
    AscendC::LocalTensor<float> ubPerTokenScaleOutput;         // staged output scales
    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
    CopyUbToGmDequantScale copyUbToGmDequantScale;
};
} // namespace Catlass::Epilogue::Block
#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP

View File

@@ -0,0 +1,330 @@
#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "hccl_shmem.hpp"
#include "layout3d.hpp"
namespace Catlass::Epilogue::Block {
// Per-token dequant epilogue "V2": dequantizes a C tile in UB and scatters the
// resulting rows directly into remote peers' shared memory (the combine step),
// using a hardcoded 2-slot ping-pong UB scheme.
template <
    uint32_t UB_STAGES_,
    class CType_,
    class LayoutPerTokenScale_,
    class DType_,
    class TileCopy_
>
class BlockEpilogue <
    EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>,
    CType_,
    Gemm::GemmType<float, LayoutPerTokenScale_>,
    DType_,
    TileCopy_
> {
public:
    using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2<UB_STAGES_>;
    using ArchTag = typename DispatchPolicy::ArchTag;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
    // Data infos
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using ElementPerTokenScale = float;
    using LayoutPerTokenScale = LayoutPerTokenScale_;
    using ElementD = typename DType_::Element;
    using LayoutD = typename DType_::Layout;
    //using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::RowMajor>>;
    using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub<ArchTag, Gemm::GemmType<float, layout::VectorLayout>>;
    // Tile copy
    using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
    using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
    // Expert-parallel routing info:
    //   EP x expertPerRank token-count table, shared-memory handle to peers,
    //   and layout/stride data needed to address the remote destination.
    struct Params {
        __gm__ int32_t *ptrTokenPerExpert{nullptr};
        int32_t EP;              // number of expert-parallel ranks
        int32_t expertPerRank;   // experts hosted on each rank
        int32_t n2;              // source row stride (full N of the workspace)
        LayoutC layoutC;         // layout of the remote destination tensor
        int32_t n0;              // UB row stride (tile N)
        int32_t rank;            // this rank's index
        HcclShmem shmem;         // handle for addressing peers' GM
        int32_t offsetD;         // byte/element offset of D inside peer shmem -- TODO confirm unit
        CATLASS_DEVICE
        Params() {};
        CATLASS_DEVICE
        Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_,
            LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) :
            ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_),
            expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_),
            shmem(shmem_), offsetD(offsetD_)
        {}
    };
    // Pre-sets the ping/pong flags and carves the 2-slot UB buffers.
    // NOTE(review): the member arrays are sized UB_STAGES but only slots 0/1
    // are initialized and `is_ping` indexing assumes exactly two stages --
    // confirm this specialization is only instantiated with UB_STAGES == 2.
    CATLASS_DEVICE
    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
    {
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
        AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
        // UB is 192KB total on this arch.
        n0 = params.n0;
        size_t ubOffset = 0;
        for(int32_t i = 0; i < 2; i++) {
            ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
            ubOffset += max_len * sizeof(ElementC);
            ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
            ubOffset += max_len * sizeof(ElementD);
            ubFp32List[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += max_len * sizeof(float);
            scaleUbList[i] = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
            ubOffset += (max_len / n0) * sizeof(float);
            // -1 = "no scale cached yet" sentinel for the scale-reload cache.
            source_scale_offset[i] = -1;
        }
        tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert));
        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
        is_ping = true;
    }
    // Drains every flag set in the constructor so no event IDs stay pending.
    CATLASS_DEVICE
    void Finalize()
    {
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID1);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID2);
        AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(EVENT_ID3);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID1);
    }
    CATLASS_DEVICE
    ~BlockEpilogue()
    {
    }
    // Dequantizing variant: loads a C tile, casts to fp32, multiplies each row
    // by its cached per-token scale, casts to ElementD, then scatters the row
    // ranges that belong to each destination rank into that peer's shmem.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        AscendC::GlobalTensor<ElementPerTokenScale> const &gmPerTokenScale,
        GemmCoord& blockCoord,
        GemmCoord& actualBlockShape,
        int32_t groupIdx,
        int32_t preSrcExpertSum,
        AscendC::GlobalTensor<int32_t> preSumBeforeRank,
        uint32_t *mPreSumBeforeRank
    ){
        // Flip ping/pong; slot 0/1 of every buffer array follows is_ping.
        is_ping = !is_ping;
        auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
        auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3;
        auto &ubC = ubCList[is_ping];
        auto &ubD = ubDList[is_ping];
        int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
        auto gmTileC = gmC[gmCOffset];
        auto &ubCFp32 = ubFp32List[is_ping];
        auto &scaleUb = scaleUbList[is_ping];
        // auto &ubOutFp32 = ubOutFp32List[is_ping];
        LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
        LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
        // Load the C tile once the previous vector pass released this slot.
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id); //for debug
        copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id); //for debug
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id);
        AscendC::Cast<float, ElementC, false>(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4});
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
        AscendC::WaitFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
        int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m();
        layout::VectorLayout scaleLauout{actualBlockShape.m()};
        // Reload the per-token scales only when this slot's cached segment
        // does not already cover the current row range.
        if (source_scale_offset[event_id] != gmScaleOffset) {
            source_scale_offset[event_id] = gmScaleOffset;
            copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout);
        }
        AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(event_id_2);
        // NOTE: must be MTE2_S, not MTE2_V, otherwise zeros are read and the
        // output is corrupted.
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(event_id_2);
        AscendC::PipeBarrier<PIPE_V>();
        // Row-wise dequant: each UB row is multiplied by its scalar scale.
        for (int32_t row = 0; row < actualBlockShape.m(); ++row) {
            float scale = scaleUb(row);
            Muls<float, false>(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8});
        }
        AscendC::PipeBarrier<PIPE_V>();
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(event_id);
        AscendC::Cast<ElementD, float, false>(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8});
        AscendC::SetFlag<AscendC::HardEvent::S_MTE2>(event_id_2);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(event_id_2);
        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(event_id);
        // Scatter: intersect this tile's row range [stTile, edTile) with each
        // destination rank's contiguous token segment and copy the overlap.
        int32_t lenTile = actualBlockShape.m();
        int32_t stTile = blockCoord.m();
        int32_t edTile = stTile + lenTile;
        int32_t preSumRankInExpert = 0;
        int32_t tileOffset = 0;
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(event_id); //for debug
        for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
            int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
            // Stride 16 int32 = 64B; presumably one cache line per rank -- TODO confirm.
            int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
            int32_t stRankInExpert = preSumRankInExpert;
            int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
            preSumRankInExpert += lenRankInExpert;
            if (stRankInExpert >= edTile) {
                break;
            }
            else if (edRankInExpert <= stTile) {
                continue;
            }
            int32_t stData = max(stRankInExpert, stTile);
            int32_t edData = min(edRankInExpert, edTile);
            uint32_t lenData = edData - stData;
            if (lenData <= 0){
                continue;
            }
            uint32_t dstOffsetInExpert = 0;
            if (stTile > stRankInExpert) {
                dstOffsetInExpert = stTile - stRankInExpert;
            }
            // Resolve the destination rank's shared-memory base for D.
            AscendC::GlobalTensor<ElementD> gmRemotePeer;
            __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
            gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
            MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
            int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
            auto gmTileD = gmRemotePeer[gmDstOffset];
            LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
            LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
            copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2);
            tileOffset += lenData;
        }
        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(event_id);
    }
    // Pass-through variant (no per-token scale): the loaded C tile is scattered
    // to the remote peers directly from ubC without dequantization.
    CATLASS_DEVICE
    void operator() (
        AscendC::GlobalTensor<ElementC> const &gmC,
        GemmCoord& blockCoord,
        GemmCoord& actualBlockShape,
        int32_t groupIdx,
        int32_t preSrcExpertSum,
        AscendC::GlobalTensor<int32_t> preSumBeforeRank,
        uint32_t *mPreSumBeforeRank
    ){
        is_ping = !is_ping;
        auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
        auto &ubC = ubCList[is_ping];
        auto &ubD = ubDList[is_ping];
        int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n();
        auto gmTileC = gmC[gmCOffset];
        LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2};
        LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0};
        // NOTE(review): MTE3_MTE2 is set here but only waited at the END of
        // this call, unlike the usual wait-before-load pattern; verify this
        // cannot let the GM->UB load below overwrite a ubC slot whose remote
        // stores from the previous same-slot call are still in flight.
        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(event_id); //for debug
        copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
        int32_t lenTile = actualBlockShape.m();
        int32_t stTile = blockCoord.m();
        int32_t edTile = stTile + lenTile;
        int32_t preSumRankInExpert = 0;
        int32_t tileOffset = 0;
        // Stores must not start before the load above has finished.
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(event_id); //for debug
        for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) {
            int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx));
            int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * 16);
            int32_t stRankInExpert = preSumRankInExpert;
            int32_t edRankInExpert = stRankInExpert + lenRankInExpert;
            preSumRankInExpert += lenRankInExpert;
            if (stRankInExpert >= edTile) {
                break;
            }
            else if (edRankInExpert <= stTile) {
                continue;
            }
            int32_t stData = max(stRankInExpert, stTile);
            int32_t edData = min(edRankInExpert, edTile);
            uint32_t lenData = edData - stData;
            if (lenData <= 0){
                continue;
            }
            uint32_t dstOffsetInExpert = 0;
            if (stTile > stRankInExpert) {
                dstOffsetInExpert = stTile - stRankInExpert;
            }
            AscendC::GlobalTensor<ElementD> gmRemotePeer;
            __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx);
            gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr));
            MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset + mPreSumBeforeRank[dstEpIdx], blockCoord.n()};
            int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset);
            auto gmTileD = gmRemotePeer[gmDstOffset];
            LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2};
            LayoutC layoutUB2{lenData, actualBlockShape.n(), n0};
            // Store straight from ubC (no dequant); ubD is unused here.
            copyUbToGmD(gmTileD, ubC[tileOffset * n0], layoutGM2, layoutUB2);
            tileOffset += lenData;
        }
        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(event_id);
    }
private:
    Params params;
    AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];  // raw C tile (ping/pong)
    AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];  // dequantized output tile
    AscendC::LocalTensor<float> ubFp32List[UB_STAGES];  // fp32 working copy of C
    AscendC::LocalTensor<float> scaleUbList[UB_STAGES]; // cached per-token scales
    int32_t source_scale_offset[UB_STAGES];  // GM offset of the cached scales (-1 = none)
    int32_t max_len = 8 * 32 / 4 * 128;      // per-slot UB tile capacity in elements (8192)
    int32_t n0;                              // UB row stride, copied from Params
    bool is_ping = false;                    // ping/pong slot selector
    int32_t repeat = 128;                    // vector repeat count for Cast ops
    CopyGmToUbC copyGmToUbC;
    CopyUbToGmD copyUbToGmD;
    CopyScaleGmToUb copyScaleGmToUb;
    AscendC::GlobalTensor<int32_t> tokenPerExpert;  // EP x expertPerRank token counts
    Layout3D tokenPerExpertLayout;                  // indexer into tokenPerExpert
};
}
#endif

View File

@@ -0,0 +1,502 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/helper.hpp"
#include "dispatch_policy_custom.hpp"
namespace Catlass::Gemm::Block {
// Same-core pipeline barrier on a single hardware event: immediately sets
// and then waits on the given event ID, forcing the two pipes tied to
// `event` (e.g. MTE2 -> MTE3) to rendezvous at this point.
template<AscendC::HardEvent event>
__aicore__ inline void SyncFlagFunc(int32_t eventID)
{
    AscendC::SetFlag<event>(eventID);
    AscendC::WaitFlag<event>(eventID);
}
/// Block-level matmul for Atlas A2 with asynchronous GM->L1 preloading and
/// Fixpipe output. When ElementA is int8_t, the L0C result is dequantized
/// per channel during the Fixpipe copy using a uint64 scale vector that is
/// staged through L1 (loaded from gmBlockS).
///
/// Up to PRELOAD_STAGES k-tile loads are kept in flight: operator() only
/// issues GM->L1 copies and enqueues the mmad parameters; the matching
/// computation (L1TileMmad) runs once the preload queue is full, or when
/// SynchronizeBlock() drains it. All buffers (L1/L0A/L0B/L0C) are
/// multi-staged and guarded by hardware event flags.
template <
    uint32_t PRELOAD_STAGES_,
    uint32_t L1_STAGES_,
    uint32_t L0A_STAGES_,
    uint32_t L0B_STAGES_,
    uint32_t L0C_STAGES_,
    bool ENABLE_UNIT_FLAG_,
    bool ENABLE_SHUFFLE_K_,
    class L1TileShape_,
    class L0TileShape_,
    class AType_,
    class BType_,
    class CType_,
    class BiasType_,
    class TileCopy_,
    class TileMmad_
>
struct BlockMmad <
    MmadAtlasA2PreloadAsyncFixpipe<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    >,
    L1TileShape_,
    L0TileShape_,
    AType_,
    BType_,
    CType_,
    BiasType_,
    TileCopy_,
    TileMmad_
> {
public:
    // Type Aliases
    using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    >;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using L1TileShape = L1TileShape_;
    using L0TileShape = L0TileShape_;
    using ElementA = typename AType_::Element;
    using LayoutA = typename AType_::Layout;
    using ElementB = typename BType_::Element;
    using LayoutB = typename BType_::Layout;
    using ElementC = typename CType_::Element;
    using LayoutC = typename CType_::Layout;
    using TileMmad = TileMmad_;
    using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
    using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
    // Copies the per-channel dequant scale vector (uint64 per channel) into L1.
    using CopyGmToL1S = Gemm::Tile::CopyGmToL1<ArchTag, Gemm::GemmType<uint64_t, layout::VectorLayout>>;
    using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
    using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
    using ElementAccumulator =
        typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
    // int8 input selects the per-channel-dequant Fixpipe copy; any other input
    // type uses the plain L0C->GM copy supplied by TileCopy_.
    using CopyL0CToGm = typename std::conditional<
        std::is_same_v<ElementA, int8_t>,
        Gemm::Tile::CopyL0CToGm<ArchTag, ElementAccumulator, CType_, Gemm::Tile::ScaleGranularity::PER_CHANNEL>,
        typename TileCopy_::CopyL0CToGm
    >::type;
    using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
    using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
    using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
    using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
    using LayoutCInL0 = layout::zN;
    using L1AAlignHelper = Gemm::helper::L1AlignHelper<ElementA, LayoutA>;
    using L1BAlignHelper = Gemm::helper::L1AlignHelper<ElementB, LayoutB>;
    static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES;
    static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES;
    static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES;
    static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES;
    static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES;
    static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
    static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K;
    // L1 tile size
    static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
    static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t);
    // L0 tile size
    static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
    static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
    static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator);
    // Check LayoutC
    static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");
    // Check L1TileShape (the int8 path additionally reserves the scale buffer)
    static_assert(
        (std::is_same_v<ElementA, int8_t>
            ? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE
            : (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE),
        "L1TileShape exceeding the L1 space for the given data type"
    );
    // Check L0TileShape
    static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!");
    static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!");
    static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!");
    static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N,
        "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet");
    static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout<ElementA>(
        L1TileShape::M, L1TileShape::K);
    static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout<ElementB>(
        L1TileShape::K, L1TileShape::N);

    /// Partition L1/L0 buffers into stages and pre-set the "buffer free"
    /// event flags so the first loads do not block.
    CATLASS_DEVICE
    BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
    {
        syncGroupIdx = 0;
        InitL1(resource, l1BufAddrStart);
        InitL0A(resource);
        InitL0B(resource);
        InitL0C(resource);
    }

    /// Drain any queued tile computations, then consume every event flag set
    /// in the Init* methods so no flag leaks past this object's lifetime.
    CATLASS_DEVICE
    ~BlockMmad()
    {
        SynchronizeBlock();
        for (uint32_t i = 0; i < L1_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
        if constexpr (std::is_same_v<ElementA, int8_t>) {
            AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
        }
    }

    /// Process one output block: for every k-tile, issue the GM->L1 loads of
    /// A and B and enqueue the mmad parameters. A tile's computation is only
    /// executed once PRELOAD_STAGES loads are in flight, keeping MTE2 ahead
    /// of the cube unit. On the final k-tile the output tensors and
    /// syncLoopIdx/flag (consumed by Finalize) are recorded in the queue.
    CATLASS_DEVICE
    void operator()(
        AscendC::GlobalTensor<ElementA> const &gmBlockA, LayoutA const &layoutA,
        AscendC::GlobalTensor<ElementB> const &gmBlockB, LayoutB const &layoutB,
        AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutC,
        AscendC::GlobalTensor<uint64_t> const &gmBlockS, layout::VectorLayout const &layoutScale,
        GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0
    )
    {
        uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
        uint32_t mRound = RoundUp<L1AAlignHelper::M_ALIGNED>(actualShape.m());
        uint32_t nRound = RoundUp<L1BAlignHelper::N_ALIGNED>(actualShape.n());
        uint32_t startTileIdx = 0;
        if constexpr (ENABLE_SHUFFLE_K) {
            // Stagger the starting k-tile per core to spread GM bandwidth.
            startTileIdx = AscendC::GetBlockIdx() % kTileCount;
        }
        for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
            // Wrap-around k-tile index (avoids a modulo).
            uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ?
                (startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount);
            uint32_t kActual = (kTileIdx < kTileCount - 1) ?
                L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K);
            // Emission load instruction from GM to L1
            MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K};
            MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0};
            auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)];
            auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)];
            // Load first matrix A tile from GM to L1
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
            auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
            copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
            // Load first matrix B tile from GM to L1
            AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
            auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
            copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB);
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);
            // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile
            if (preloadCount == PRELOAD_STAGES) {
                L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            }
            // Store the current load status
            uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ?
                (l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES);
            auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId];
            l1TileMmadParams.l1ListId = l1ListId;
            l1TileMmadParams.mRound = mRound;
            l1TileMmadParams.nRound = nRound;
            l1TileMmadParams.kActual = kActual;
            l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0);
            l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1);
            l1TileMmadParams.flag = flag;
            if (kLoopIdx == kTileCount - 1) {
                // Output destination is only needed by the tile that writes GM.
                l1TileMmadParams.gmBlockC = gmBlockC;
                l1TileMmadParams.gmBlockS = gmBlockS;
                l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN());
                l1TileMmadParams.layoutScale = layoutScale;
                l1TileMmadParams.syncLoopIdx = syncLoopIdx;
            }
            if (preloadCount < PRELOAD_STAGES) {
                ++preloadCount;
            } else {
                l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            }
            l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0;
        }
    }

    /// Drain the preload queue: run the pending mmad for every tile whose
    /// loads were issued but whose computation has not happened yet.
    CATLASS_DEVICE
    void SynchronizeBlock()
    {
        while (preloadCount > 0) {
            L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]);
            l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0;
            --preloadCount;
        }
    }

    /// Raise the cross-core flag for every sync group up to `target`
    /// (one flag per 8 groups, offset by `flag`). A negative target is a
    /// no-op, which is why -1 is the default syncLoopIdx in operator().
    CATLASS_DEVICE
    void Finalize(int32_t target, int32_t flag = 0)
    {
        for (; syncGroupIdx <= target; syncGroupIdx++) {
            int32_t flagId = syncGroupIdx / 8 + flag;
            AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId);
        }
    }
private:
    // Everything L1TileMmad needs to compute one queued k-tile; output
    // fields (gmBlockC/gmBlockS/layouts/syncLoopIdx) are only valid when
    // isKLoopLast is true.
    struct L1TileMmadParams {
        uint32_t l1ListId;
        uint32_t mRound;
        uint32_t nRound;
        uint32_t kActual;
        bool isKLoopFirst;
        bool isKLoopLast;
        AscendC::GlobalTensor<ElementC> gmBlockC;
        AscendC::GlobalTensor<uint64_t> gmBlockS;
        LayoutC layoutCInGm;
        layout::VectorLayout layoutScale;
        int32_t syncLoopIdx;
        int32_t flag;
        CATLASS_DEVICE
        L1TileMmadParams() = default;
    };

    // Slice the L1 buffer into A/B (and, for int8, scale) stages and mark
    // every stage free via MTE1_MTE2 flags.
    CATLASS_DEVICE
    void InitL1(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart)
    {
        uint32_t l1AOffset = l1BufAddrStart;
        uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES;
        for (uint32_t i = 0; i < L1_STAGES; ++i) {
            l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_TILE_SIZE * i);
            l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_TILE_SIZE * i);
            l1AEventList[i] = i;
            l1BEventList[i] = i + L1_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
            AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
        }
        if constexpr (std::is_same_v<ElementA, int8_t>) {
            uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES;
            l1STensor = resource.l1Buf.template GetBufferByByte<uint64_t>(l1SOffset);
            AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
        }
    }

    // Slice L0A into stages and mark them free (M_MTE1).
    CATLASS_DEVICE
    void InitL0A(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0A_STAGES; ++i) {
            l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_TILE_SIZE * i);
            l0AEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
        }
    }

    // Slice L0B into stages; event IDs continue after the L0A range.
    CATLASS_DEVICE
    void InitL0B(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0B_STAGES; ++i) {
            l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_TILE_SIZE * i);
            l0BEventList[i] = i + L0A_STAGES;
            AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
        }
    }

    // Slice L0C into stages and mark them free (FIX_M).
    CATLASS_DEVICE
    void InitL0C(Arch::Resource<ArchTag> &resource)
    {
        for (uint32_t i = 0; i < L0C_STAGES; ++i) {
            l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(L0C_TILE_SIZE * i);
            l0CEventList[i] = i;
            AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[i]);
        }
    }

    // Compute one queued k-tile: L1->L0 copies over the m/k/n partitions,
    // mmad accumulation into L0C, and on the last k-tile a Fixpipe copy of
    // the accumulator to GM (with the per-channel scale staged in L1 when A
    // is int8), followed by the cross-core Finalize signal.
    CATLASS_DEVICE
    void L1TileMmad(L1TileMmadParams const &params)
    {
        uint32_t mPartLoop = CeilDiv<L0TileShape::M>(params.mRound);
        uint32_t nPartLoop = CeilDiv<L0TileShape::N>(params.nRound);
        uint32_t kPartLoop = CeilDiv<L0TileShape::K>(params.kActual);
        auto &l1ATensor = l1ATensorList[params.l1ListId];
        auto &l1BTensor = l1BTensorList[params.l1ListId];
        auto &l0CTensor = l0CTensorList[l0CListId];
        LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound));
        if constexpr (!ENABLE_UNIT_FLAG) {
            if (params.isKLoopFirst) {
                AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            }
        }
        for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) {
            uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ?
                L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M);
            for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) {
                uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ?
                    L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K);
                auto &l0ATile = l0ATensorList[l0AListId];
                auto layoutAInL0 = LayoutAInL0::template MakeLayout<ElementA>(mPartActual, kPartActual);
                auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK();
                auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)];
                AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                if ((mPartIdx == 0) && (kPartIdx == 0)) {
                    // First touch of this L1 stage: wait for its GM->L1 load.
                    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[params.l1ListId]);
                }
                copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT);
                if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) {
                    // Last read of this L1 stage: release it for the next load.
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[params.l1ListId]);
                }
                for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) {
                    uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ?
                        L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N);
                    auto &l0BTile = l0BTensorList[l0BListId];
                    auto layoutBInL0 = LayoutBInL0::template MakeLayout<ElementB>(kPartActual, nPartActual);
                    auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN();
                    auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)];
                    AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    if ((kPartIdx == 0) && (nPartIdx == 0)) {
                        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[params.l1ListId]);
                    }
                    copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT);
                    if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
                        AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[params.l1ListId]);
                    }
                    // Order the L1->L0 copies before the mmad issue.
                    AscendC::SetFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
                    auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN();
                    auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)];
                    AscendC::WaitFlag<AscendC::HardEvent::MTE1_M>(EVENT_ID0);
                    // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0
                    bool initC = (params.isKLoopFirst && (kPartIdx == 0));
                    // If the unit flag is enabled, the unit flag is set according to the calculation progress
                    uint8_t unitFlag = 0b00;
                    if constexpr (ENABLE_UNIT_FLAG) {
                        if (params.isKLoopLast &&
                            (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) {
                            unitFlag = 0b11;
                        } else {
                            unitFlag = 0b10;
                        }
                    }
                    tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
                    AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[l0BListId]);
                    l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0;
                }
                AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[l0AListId]);
                l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0;
            }
        }
        if (params.isKLoopLast) {
            auto layoutCInGm = params.layoutCInGm;
            if constexpr (std::is_same_v<ElementA, int8_t>) {
                // Stage the per-channel scale vector into L1 for the Fixpipe.
                auto layoutScale = params.layoutScale;
                auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1)));
                AscendC::WaitFlag<AscendC::HardEvent::FIX_MTE2>(0);
                copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS);
                AscendC::SetFlag<AscendC::HardEvent::MTE2_FIX>(0);
                AscendC::WaitFlag<AscendC::HardEvent::MTE2_FIX>(0);
            }
            if constexpr (!ENABLE_UNIT_FLAG) {
                AscendC::SetFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(l0CEventList[l0CListId]);
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0);
                } else {
                    copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0);
                }
                AscendC::SetFlag<AscendC::HardEvent::FIX_M>(l0CEventList[l0CListId]);
            } else {
                if constexpr (std::is_same_v<ElementA, int8_t>) {
                    copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11);
                } else {
                    copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11);
                }
            }
            l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0;
            if constexpr (std::is_same_v<ElementA, int8_t>) {
                AscendC::SetFlag<AscendC::HardEvent::FIX_MTE2>(0);
            }
            Finalize(params.syncLoopIdx, params.flag);
        }
    }

    AscendC::LocalTensor<ElementA> l1ATensorList[L1_STAGES];
    AscendC::LocalTensor<ElementB> l1BTensorList[L1_STAGES];
    AscendC::LocalTensor<uint64_t> l1STensor;   // L1 staging for dequant scales (int8 path)
    int32_t syncGroupIdx;                       // next cross-core sync group Finalize will signal
    int32_t l1AEventList[L1_STAGES];
    int32_t l1BEventList[L1_STAGES];
    uint32_t l1ListId{0};
    AscendC::LocalTensor<ElementA> l0ATensorList[L0A_STAGES];
    int32_t l0AEventList[L0A_STAGES];
    uint32_t l0AListId{0};
    AscendC::LocalTensor<ElementB> l0BTensorList[L0B_STAGES];
    int32_t l0BEventList[L0B_STAGES];
    uint32_t l0BListId{0};
    AscendC::LocalTensor<ElementAccumulator> l0CTensorList[L0C_STAGES_];
    int32_t l0CEventList[L0C_STAGES_];
    uint32_t l0CListId{0};
    // Ring buffer of pending tile computations (the preload queue).
    L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
    uint32_t l1TileMmadParamsId{0};
    uint32_t preloadCount{0};
    TileMmad tileMmad;
    CopyGmToL1A copyGmToL1A;
    CopyGmToL1B copyGmToL1B;
    CopyGmToL1S copyGmToL1S;
    CopyL1ToL0A copyL1ToL0A;
    CopyL1ToL0B copyL1ToL0B;
    CopyL0CToGm copyL0CToGm;
};
} // namespace Catlass::Gemm::Block
#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP

View File

@@ -0,0 +1,9 @@
#ifndef CONST_ARGS_HPP
#define CONST_ARGS_HPP

#include <cstdint>  // fixed-width integer types used below; previously relied on transitive includes

// One mebibyte, in bytes.
constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
// Number of int32 slots reserved per synchronization flag
// (presumably cache-line padding between flags - confirm against sync_util usage).
constexpr static int32_t NUMS_PER_FLAG = 16;
// Cache-line granularity, in bytes, used for flag/buffer spacing.
constexpr static int32_t CACHE_LINE = 512;
// Sentinel value used to reset flags.
constexpr static int32_t RESET_VAL = 0xffff;
// Upper bound of experts hosted on a single rank.
// (Qualifier made consistent with the other constants in this header.)
constexpr static uint32_t MAX_EXPERTS_PER_RANK = 32;
#endif  // CONST_ARGS_HPP

View File

@@ -0,0 +1,40 @@
#ifndef COPY_GM_TO_L1_CUSTOM_HPP
#define COPY_GM_TO_L1_CUSTOM_HPP
namespace Catlass::Gemm::Tile {
/// Partial specialization copying a 1-D vector tile from global memory (GM)
/// into an L1 buffer. Source and destination both use VectorLayout; the copy
/// length is rounded up to whole 32-byte C0 blocks.
template <
    class ArchTag,
    class Element
>
struct CopyGmToL1<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
    using LayoutDst = layout::VectorLayout;
    using LayoutSrc = layout::VectorLayout;
    // Elements per 32-byte C0 block (e.g. 4 for 8-byte element types).
    static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);

    // Methods
    CATLASS_DEVICE
    CopyGmToL1() {};

    CATLASS_DEVICE
    void operator()(
        AscendC::LocalTensor<Element> const &dstTensor,
        AscendC::GlobalTensor<Element> const &srcTensor,
        LayoutDst const &layoutDst, LayoutSrc const &layoutSrc)
    {
        // Single contiguous burst: one block run, length in C0 units,
        // no gaps on either side.
        AscendC::DataCopyParams copyParams;
        copyParams.blockCount = 1;
        copyParams.blockLen = CeilDiv<ELE_NUM_PER_C0>(layoutSrc.shape(0));
        copyParams.srcStride = 0;
        copyParams.dstStride = 0;
        AscendC::DataCopy(dstTensor, srcTensor, copyParams);
    }
};
}
#endif // COPY_GM_TO_L1_CUSTOM_HPP

View File

@@ -0,0 +1,47 @@
#ifndef COPY_L0C_TO_GM_CUSTOM_HPP
#define COPY_L0C_TO_GM_CUSTOM_HPP
namespace Catlass::Gemm::Tile {
/// Fixpipe-based L0C -> GM copy with per-channel dequantization for AtlasA2.
/// The per-channel scale vector is supplied through an L1 (cbuf) workspace.
template <
    class ElementAccumulator_,
    class ElementDst_,
    bool ReluEnable_
>
struct CopyL0CToGm<Catlass::Arch::AtlasA2,
    ElementAccumulator_,
    Gemm::GemmType<ElementDst_, layout::RowMajor>,
    ScaleGranularity::PER_CHANNEL,
    ReluEnable_>
{
    using ArchTag = Catlass::Arch::AtlasA2;
    using ElementDst = ElementDst_;
    using ElementSrc = ElementAccumulator_;
    using LayoutSrc = Catlass::layout::zN;
    using LayoutDst = Catlass::layout::RowMajor;
    // Pre-quantization mode derived from the src/dst element pair.
    static constexpr auto quantPre = CopyL0CToGmQuantMode<ArchTag, ElementSrc, ElementDst,
        ScaleGranularity::PER_CHANNEL>::VALUE;
    static constexpr auto reluEn = ReluEnable_;

    CATLASS_DEVICE
    void operator()(AscendC::GlobalTensor<ElementDst> const &dst, AscendC::LocalTensor<ElementSrc> const &src, AscendC::LocalTensor<uint64_t> cbufWorkspace,
        LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0)
    {
        AscendC::FixpipeParamsV220 fixParams;
        // Tile geometry: rows/cols from the destination, strides from both layouts.
        fixParams.mSize = dstLayout.shape(0);
        fixParams.nSize = dstLayout.shape(1);
        fixParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0);
        fixParams.dstStride = dstLayout.stride(0);
        // Auxiliary controls: per-channel dequant, optional relu, unit-flag pipelining.
        fixParams.quantPre = quantPre;
        fixParams.reluEn = reluEn;
        fixParams.unitFlag = unitFlag;
        // Issue the Fixpipe move; cbufWorkspace carries the scale vector.
        AscendC::Fixpipe<ElementDst, ElementSrc, AscendC::CFG_ROW_MAJOR>(dst, src, cbufWorkspace, fixParams);
    }
};
}
#endif // COPY_L0C_TO_GM_CUSTOM_HPP

View File

@@ -0,0 +1,53 @@
#ifndef DISPATCH_POLICY_CUSTOM_HPP  // guard renamed: was misspelled "DISPATH", file is dispatch_policy_custom.hpp
#define DISPATCH_POLICY_CUSTOM_HPP
namespace Catlass::Gemm {
// Dispatch tag for the fixpipe-quant preload mmad. Extends the base
// MmadAtlasA2 policy with compile-time unit-flag and shuffle-K switches.
template <bool ENABLE_UNIT_FLAG_ = false, bool ENABLE_SHUFFLE_K_ = false>
struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 {
    static constexpr uint32_t STAGES = 2;
    static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_;
    static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_;
};
// Dispatch tag selecting the preload-async mmad variant whose output path
// goes through Fixpipe; all staging parameters are inherited unchanged from
// MmadAtlasA2PreloadAsync.
template <uint32_t PRELOAD_STAGES_, uint32_t L1_STAGES_, uint32_t L0A_STAGES_, uint32_t L0B_STAGES_,
    uint32_t L0C_STAGES_, bool ENABLE_UNIT_FLAG_, bool ENABLE_SHUFFLE_K_>
struct MmadAtlasA2PreloadAsyncFixpipe :
    public MmadAtlasA2PreloadAsync<
        PRELOAD_STAGES_,
        L1_STAGES_,
        L0A_STAGES_,
        L0B_STAGES_,
        L0C_STAGES_,
        ENABLE_UNIT_FLAG_,
        ENABLE_SHUFFLE_K_
    > {
};
}  // namespace Catlass::Gemm
namespace Catlass::Epilogue {
// Epilogue dispatch tags for Atlas A2. Each tag fixes only the architecture
// and the number of UB pipeline stages; the concrete epilogue behavior is
// selected by the tag type itself (names indicate the intended variant).
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2UnQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantSwigluQuant {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
template <uint32_t UB_STAGES_>
struct EpilogueAtlasA2PerTokenDequantV2 {
    using ArchTag = Arch::AtlasA2;
    static constexpr uint32_t UB_STAGES = UB_STAGES_;
};
}  // namespace Catlass::Epilogue
#endif  // DISPATCH_POLICY_CUSTOM_HPP

View File

@@ -0,0 +1,16 @@
#ifndef GET_TENSOR_ADDR_HPP
#define GET_TENSOR_ADDR_HPP
#include "kernel_operator.h"
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
// Resolve the GM address of the index-th tensor in a packed tensor list.
// tensorPtr points to a header whose first uint64 holds the byte offset of
// the address table (relative to tensorPtr); entry `index` of that table is
// the tensor's actual GM address.
template <typename T>
FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) {
    __gm__ uint64_t* header = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
    // >> 3 converts the byte offset into uint64_t slots (sizeof(uint64_t) == 8).
    __gm__ uint64_t* addrTable = header + (*header >> 3);
    return reinterpret_cast<__gm__ T*>(addrTable[index]);
}
#endif // GET_TENSOR_ADDR_HPP

View File

@@ -0,0 +1,195 @@
#ifndef SYNC_UTIL_HPP
#define SYNC_UTIL_HPP
#include "kernel_operator.h"
#include "const_args.hpp"
#include "moe_distribute_base.h"
#ifndef HCCL_COMM
#include "shmem_api.h"
using namespace AscendC::HcclContextDef;
#endif
#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
// Maximum number of ranks tracked by the shmem window bookkeeping.
constexpr int32_t MAX_RANK_SIZE = 32;
// Default symmetric segment size (1 GiB) used when no HCCL context supplies one.
constexpr int32_t SHMEM_MEM = 1024 * MB_SIZE;
// Cross-core barrier for AIC cores: set then wait on cross-core flag 8
// in mode 0x0 on the FIX pipe.
// NOTE(review): flag id 8 / mode 0x0 assumed to be the reserved all-core
// barrier channel for this kernel family - confirm against the flag map.
FORCE_INLINE_AICORE void AicSyncAll() {
    AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8);
    AscendC::CrossCoreWaitFlag<0x0>(8);
}
// Scalar store of a single value to global memory.
template<typename T>
FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) {
    *addr = val;
}
// Scalar load of a single value from global memory.
template<typename T>
FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) {
    return *cache;
}
// Clean-and-invalidate the single cache line containing addr so later GM
// reads observe remote writes (and local writes become globally visible).
// The empty volatile asm statements act as compiler barriers so the dcci
// is neither reordered nor elided.
FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) {
    using namespace AscendC;
    GlobalTensor<uint8_t> global;
    global.SetGlobalBuffer(addr);
    // Important: add hint to avoid dcci being optimized by compiler
    __asm__ __volatile__("");
    DataCacheCleanAndInvalid<uint8_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(global);
    __asm__ __volatile__("");
}
// Spin until the GM signal reaches the current barrier round (cmp_val) or
// one round beyond it (cmp_val + 1, written by a peer that already entered
// the next round). Each probe invalidates the cache line first so remote
// stores become visible. Returns the observed value; the trailing
// `return -1` is unreachable.
FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    do {
        gm_dcci((__gm__ uint8_t *)sig_addr);
        if (*sig_addr == cmp_val) {
            return *sig_addr;
        }
        // Tolerate one-round overshoot from a faster peer.
        if (*sig_addr == cmp_val + 1) {
            return *sig_addr;
        }
    } while (true);
    return -1;  // unreachable
}
// Spin until the GM signal differs from cmp_val. The value is fetched by an
// MTE2 DataCopy into a manually-addressed UB tensor (logical position VECIN,
// buffer offset 0) rather than a scalar load; the MTE2_S set/wait pair
// orders the copy before the scalar read of ub(0).
// NOTE(review): assumes UB offset 0 is free scratch at every call site -
// confirm against the surrounding kernels' UB allocation.
FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
    do {
        AscendC::LocalTensor<int32_t> ub;
        ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
        ub.address_.bufferAddr = 0;
        AscendC::GlobalTensor<int32_t> sig;
        sig.SetGlobalBuffer(sig_addr);
        AscendC::DataCopy(ub, sig, 8);
        AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
        if (ub(0) != cmp_val) {
            return;
        }
    } while (true);
    return;
}
// Abstraction over the symmetric device-memory window used for cross-rank
// access. Two backends: the HCCL window context (when HCCL_COMM is defined)
// or the shmem-API symmetric heap. The operator() overloads translate
// (offset, rank) into a directly addressable GM pointer on the target rank.
class HcclShmem {
public:
#ifdef HCCL_COMM
    __gm__ HcclOpResParamCustom *WinContext_{nullptr};
    Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
    GM_ADDR m_ptrArray[MAX_RANK_SIZE];  // per-rank window base addresses
    FORCE_INLINE_AICORE
    HcclShmem(){
        // Pull rank topology and window pointers from the HCCL op context.
        auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
        WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
        m_rank = WinContext_->localUsrRankId;
        m_rankSize = WinContext_->rankSize;
        m_segmentSize = WinContext_->winSize;
        for (int i = 0; i < m_rankSize; i++) {
            // Local rank uses its own input window; remote ranks are reached
            // through the windowsIn pointer published in remoteRes.
            m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
                ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
        }
    }
#else
    FORCE_INLINE_AICORE
    HcclShmem(){
        m_segmentSize = SHMEM_MEM;
    }
    // Late initialization for the shmem backend; must be called before use
    // (rank fields are not set by the constructor on this path).
    FORCE_INLINE_AICORE
    void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) {
        symmetricPtr = symmetricPtr_;
        m_rank = rank;
        m_rankSize = rankSize;
    }
#endif
    // Base address of this rank's own window.
    FORCE_INLINE_AICORE
    GM_ADDR operator() () const {
#ifdef HCCL_COMM
        return m_ptrArray[m_rank];
#else
        return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, m_rank));
#endif
    }
    // Base address of rank `index`'s window.
    FORCE_INLINE_AICORE
    GM_ADDR operator() (int32_t index) const {
#ifdef HCCL_COMM
        return m_ptrArray[index];
#else
        return reinterpret_cast<GM_ADDR>(shmem_ptr(symmetricPtr, index));
#endif
    }
    // Address `offset` bytes into rank `rankId`'s window. The HCCL path
    // returns nullptr for an out-of-range offset or rank id.
    FORCE_INLINE_AICORE
    GM_ADDR operator () (int64_t offset, int32_t rankId) const {
#ifdef HCCL_COMM
        if (offset < 0 || offset >= m_segmentSize) {
            return nullptr;
        }
        if (rankId < 0 || rankId >= m_rankSize) {
            return nullptr;
        }
        return m_ptrArray[rankId] + offset;
#else
        return reinterpret_cast<GM_ADDR>(shmem_ptr((symmetricPtr + offset), rankId));
#endif
    }
    // Size in bytes of each rank's window.
    FORCE_INLINE_AICORE
    size_t SegmentSize() const {
        return m_segmentSize;
    }
    // Number of ranks in the communicator.
    FORCE_INLINE_AICORE
    int32_t RankSize() const {
        return m_rankSize;
    }
    FORCE_INLINE_AICORE
    ~HcclShmem() {
    }
    // Barrier across all ranks. The last MB of the window holds per-rank
    // counters spaced 16 int32 apart, plus a local round counter at +2048.
    // Each vector core publishes the new round number into its own slot on a
    // subset of peers, then spins until the matching counters arrive (equal
    // to the round, or round + 1 for a fast peer); a final on-chip SyncAll
    // aligns local cores before the round counter is committed.
    FORCE_INLINE_AICORE
    void CrossRankSync() {
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        __gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset;
        __gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + flag_offset + 2048;
        int count = gm_load(sync_base) + 1;  // next barrier round number
        int vec_id = AscendC::GetBlockIdx();
        int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation();
        for (int i = vec_id; i < m_rankSize; i += vec_size) {
            // Publish our round number into peer i's slot reserved for this rank.
            __gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16;
            gm_store(sync_remote, count);
            gm_dcci((__gm__ uint8_t*)sync_remote);
            // Wait for peer i's counter to reach this round.
            auto sync_check = sync_counter + i * 16;
            gm_signal_wait_until_eq_for_barrier(sync_check, count);
        }
        AscendC::SyncAll<true>();
        gm_store(sync_base, count);
    }
    // GM address of the local round counter used by CrossRankSync.
    FORCE_INLINE_AICORE
    __gm__ int32_t* SyncBaseAddr() {
        uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t);
        return (__gm__ int32_t*)(*this)() + flag_offset + 2048;
    }
private:
    GM_ADDR symmetricPtr;   // shmem-heap base (non-HCCL path only)
    int32_t m_rank;         // this rank's id within the communicator
    int32_t m_rankSize;     // total number of ranks
    size_t m_segmentSize;   // bytes per rank window
};
#endif

View File

@@ -0,0 +1,20 @@
#ifndef LAYOUT_3D_HPP
#define LAYOUT_3D_HPP
#include "kernel_operator.h"
#include "catlass/catlass.hpp"
class Layout3D {
int64_t strides[2];
public:
CATLASS_DEVICE
Layout3D() {}
CATLASS_DEVICE
Layout3D(int64_t stride0, int64_t stride1) {
strides[0] = stride0;
strides[1] = stride1;
}
CATLASS_DEVICE
int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) {
return dim0 * strides[0] + dim1 * strides[1] + dim2;
}
};
#endif // LAYOUT_3D_HPP

View File

@@ -0,0 +1,25 @@
#ifndef SELECT_HELPER_HPP
#define SELECT_HELPER_HPP
#include "catlass/layout/layout.hpp"
using namespace AscendC;
using namespace Catlass;
// Builds a layout-B object from its two extents (k, n). Primary template
// covers layouts that are directly brace-constructible from the extents.
template <typename Layout, typename ElementType, typename = void>
struct LayoutBInitializer {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout{k, n};
    }
};
// Specialization for the zN layout, which must be built through
// MakeLayout so the element type can be taken into account.
template <typename Layout, typename ElementType>
struct LayoutBInitializer<Layout, ElementType,
    std::enable_if_t<std::is_same_v<Layout, layout::zN>>
> {
    CATLASS_DEVICE
    static Layout create(uint32_t k, uint32_t n) {
        return Layout::template MakeLayout<ElementType>(k, n);
    }
};
#endif // SELECT_HELPER_HPP