[Feature] DispatchGmmCombineDecode support bf16/float16 gmm1/gmm2 weight and support gmm weight with ND format (#6393)

### What this PR does / why we need it?
1. support ND format gmm weight input.
Before this pr, gmm1_weight and gmm2_weight could only be passed as
input to the DispatchGmmCombineDecode operator in NZ data format. After
the modification, they are allowed to be passed in ND data format.
2. support bf16/float16 gmm weight
The current PR modification enables the DispatchGmmCombineDecode
operator to support non-W8A8 scenarios, allowing gmm1_weight and
gmm2_weight to be passed as float16/bfloat16 which is correspond with
input token data type.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: lih827 <383084552@qq.com>
This commit is contained in:
lih827
2026-02-12 10:37:41 +08:00
committed by GitHub
parent f1ffb5fb19
commit f71812011d
18 changed files with 3766 additions and 237 deletions

View File

@@ -8,6 +8,7 @@
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "dispatch_gmm_combine_decode.h"
#include "dispatch_gmm_combine_decode_bf16_fp16.h"
#include <kernel_operator.h>
#include "lib/matmul_intf.h"
@@ -25,12 +26,28 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
#if (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_INT8)
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
DispatchGmmCombineDecode<
DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op;
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) ||
TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) ||
TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) {
DispatchGmmCombineDecodeImpl::DispatchGmmCombineDecode<
DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int8_t, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
op.Process();
}
#elif (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_BF16 || ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_FLOAT16)
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) ||
TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) ||
TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) {
DispatchGmmCombineDecodeBf16Fp16Impl::DispatchGmmCombineDecodeBf16Fp16<
DTYPE_GMM1_PERMUTED_WEIGHT, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, DTYPE_GMM1_PERMUTED_WEIGHT, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
op.Process();
}
#endif
}

View File

@@ -34,7 +34,7 @@
#include "dispatch_gmm_combine_decode_base.h"
using namespace Catlass;
namespace DispatchGmmCombineDecodeImpl {
using MmadAtlasA2Custom =
Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
CUSTOM_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG,
@@ -57,7 +57,9 @@ using Gmm2DispatchPolicy =
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
layout::RowMajor layoutA, GM_ADDR gmB,
typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB,
GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
@@ -73,7 +75,8 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::zN>;
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
using BType = Gemm::GemmType<int8_t, LayoutB>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
@@ -107,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using ElementGroupList = int64_t;
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
@@ -178,7 +181,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
layout::RowMajor layoutA, GM_ADDR gmB,
typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB,
GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmWorkspace, void *combiner)
@@ -189,7 +194,8 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::zN>;
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
using BType = Gemm::GemmType<int8_t, LayoutB>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
@@ -342,6 +348,16 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
gmm2InputDim_ = gmm1OutputDim_ / 2;
}
template<uint32_t EXEC_FLAG>
__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) {
if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) {
MatrixCoord mc{k, n};
return layout::RowMajor::template MakeLayoutInUb<int8_t>(mc);
} else {
return layout::zN::template MakeLayout<int8_t>(k, n);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
{
@@ -349,11 +365,11 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};
layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
layout::zN layoutWeight1 = layout::zN::template MakeLayout<int8_t>(tokenHiddenSize_, gmm1OutputDim_);
auto layoutWeight1 = CreateWeightLayout<EXEC_FLAG>(tokenHiddenSize_, gmm1OutputDim_);
layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
layout::VectorLayout layoutX1Scale{maxTokenNum_};
layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
layout::zN layoutWeight2 = layout::zN::template MakeLayout<int8_t>(gmm2InputDim_, gmm2OutputDim_);
auto layoutWeight2 = CreateWeightLayout<EXEC_FLAG>(gmm2InputDim_, gmm2OutputDim_);
layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
layout::VectorLayout layoutX2Scale{maxTokenNum_};
layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};
@@ -436,4 +452,5 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
layoutOutput, gmWorkspace, &combiner);
}
} // namespace DispatchGmmCombineDecodeImpl
#endif // DISPATCH_GMM_COMBINE_DECODE_H

View File

@@ -12,3 +12,6 @@
#include "block_epilogue_per_token_dequant_swiglu.h"
#include "block_epilogue_per_token_dequant.hpp"
#include "block_epilogue_swiglu_bf16_fp16.h"
#include "block_epilogue_bf16_fp16.hpp"

View File

@@ -0,0 +1,337 @@
/*
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP
#define ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP
#include "../../raw_distributed/cam_moe_distribute_combine.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/matrix_coord.hpp"
namespace Catlass::Epilogue::Block {
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_, class DType_,
class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2Combine<UB_STAGES_, EXEC_FLAG_>,
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
TileCopy_, EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2Combine<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementRawScale = ScaleType_;
using ElementFp32Scale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, float> &&
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>),
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) +
TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementRawScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE void AlignUbOffset()
{
size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1);
if (ubMask != 0) {
ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask;
}
}
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo,
Params const &params = Params{})
: resource(resource), calcInfo(calcInfo), params(params)
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
AlignUbOffset();
epSendCountLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t);
AlignUbOffset();
AscendC::GlobalTensor<int32_t> epSendCountGM;
epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_);
uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_;
AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast<uint32_t>(epSendCountSize * sizeof(uint32_t)),
0U, 0U, 0U};
AscendC::DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
}
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U)
{
return (GM_ADDR)((calcInfo.epRankId_ == rankId)
? calcInfo.epWinContext_->localWindowsIn
: ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr))
->windowsIn) +
calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
}
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
{
if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) {
remoteEpRank = calcInfo.epRankId_;
localEpRank = epRank;
} else {
remoteEpRank = epRank;
localEpRank = calcInfo.epRankId_;
}
}
CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor<ElementD> &ubD, layout::RowMajor &layoutGmTileD,
LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD)
{
const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD);
const uint32_t copyTokenSrcStride =
(layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD));
const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD);
int64_t offsetD = groupOffsetD + tileOffsetD;
uint32_t startToken = offsetD / calcInfo.axisH_;
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
uint32_t itToken = startToken;
uint32_t endToken = startToken + layoutGmTileD.shape(0);
constexpr uint32_t epRankStart = 0;
uint32_t sendCount =
expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
uint32_t prevSendCount = sendCount;
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
if (prevSendCount <= itToken && itToken < sendCount) {
uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken;
AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride,
copyTokenDstStride, 0);
uint32_t remoteEpRank;
uint32_t localEpRank;
SetCombineSendEpRank(epRank, remoteEpRank, localEpRank);
GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) +
localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_;
AscendC::GlobalTensor<ElementD> rankWindow;
rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM);
AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset],
ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams);
itToken += copyTokenCount;
}
}
}
CATLASS_DEVICE
void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK,
GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK,
AscendC::GlobalTensor<ElementC> const &gmBlockC, LayoutC const &layoutBlockC,
Callback &&callback = Callback{})
{
if (actualBlockShapeMNK.k() == 0) {
return;
}
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
expertOffset = expertIdx * calcInfo.epWorldSize_;
}
callback();
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementRawScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
uint32_t subblockNum = AscendC::GetSubBlockNum();
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Cast(ubD, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
auto tileOffsetD = params.layoutD.GetOffset(tileOffset);
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD);
} else {
auto gmTileD = gmD[tileOffsetD];
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
}
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
Arch::Resource<ArchTag> &resource;
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbScaleVMTE2List[UB_STAGES];
int32_t eventUbScaleMTE2VList[UB_STAGES];
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
AscendC::LocalTensor<int32_t> epSendCountLocal_;
size_t ubOffset{0};
int32_t eventVMTE2{0};
int32_t eventMTE2V{0};
int32_t eventMTE3V{0};
int32_t eventVMTE3{0};
int32_t eventVS{0};
int32_t eventMTE2S{0};
uint32_t expertOffset;
uint32_t ubListId{0};
CopyGmToUbC copyGmToUbC;
CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block
#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP

View File

@@ -0,0 +1,259 @@
/*
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/detail/callback.hpp"
#include "../tile/tile_stride_muls.h"
#include "../tile/tile_stride_binary.h"
namespace Catlass::Epilogue::Block {
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2Swiglu<UB_STAGES_, EXEC_FLAG_>,
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
TileCopy_, EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2Swiglu<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementRawScale = ScaleType_;
using ElementFp32Scale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, float> && std::is_same_v<ElementD, float>,
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0,
"The per token scale granularity for word calculation must be 32 bytes aligned.");
static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts.");
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) +
TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementRawScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
int32_t eventMTE3MTE2 = 0;
int32_t eventMTE2MTE3 = 0;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
eventUbMTE3MTE2List[i] = eventMTE3MTE2++;
eventUbMTE2MTE3List[i] = eventMTE2MTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventUbMTE3MTE2List[i]);
}
ubDenominatorMxN = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventUbMTE3MTE2List[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
CATLASS_DEVICE
void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
LayoutC const &layoutBlockC, Callback &&callback = Callback{})
{
if (0 == actualBlockShapeMNK.k()) {
return;
}
callback();
ubListId = 0;
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
AscendC::GlobalTensor<ElementRawScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = 0; // for 1C1V
uint32_t subblockNum = 1; // for 1C1V
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
if (isLeft) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::Muls(ubDenominatorMxN, ubC, -1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT);
AscendC::PipeBarrier<PIPE_V>();
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Div(ubD, ubC, ubDenominatorMxN, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
} else {
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventUbMTE3MTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventUbMTE2MTE3List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventUbMTE2MTE3List[ubListId]);
copyUbToGmD(gmTileD, ubC, layoutGmTileD, layoutUbD);
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventUbMTE3MTE2List[ubListId]);
}
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
int32_t eventUbMTE3MTE2List[UB_STAGES];
int32_t eventUbMTE2MTE3List[UB_STAGES];
uint32_t ubListId{0};
AscendC::LocalTensor<float> ubDenominatorMxN;
TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
CopyUbToGmD copyUbToGmD;
};
} // namespace Catlass::Epilogue::Block

View File

@@ -26,4 +26,17 @@ struct EpilogueAtlasA2PerTokenDequantCombine {
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
};
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
struct EpilogueAtlasA2Swiglu {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
};
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
struct EpilogueAtlasA2Combine {
using ArchTag = Arch::AtlasA2;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
};
} // namespace Catlass::Epilogue

View File

@@ -0,0 +1,383 @@
/*
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP
#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP
#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h"
#include "../../raw_distributed/cam_moe_distribute_combine.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/coord.hpp"
#include "catlass/detail/callback.hpp"
#include "catlass/gemm_coord.hpp"
#include "catlass/matrix_coord.hpp"
namespace Catlass::Gemm::Kernel {
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_,
uint32_t WORKSPACE_STAGES_, class ElementGroupList_>
class GroupedMatmulSliceMMultiStageWorkspace
{
public:
using BlockMmad = BlockMmad_;
using ArchTag = typename BlockMmad::ArchTag;
using L1TileShape = typename BlockMmad::L1TileShape;
using ElementA = typename BlockMmad::ElementA;
using LayoutA = typename BlockMmad::LayoutA;
using ElementB = typename BlockMmad::ElementB;
using LayoutB = typename BlockMmad::LayoutB;
using ElementC = typename BlockMmad::ElementC;
using LayoutC = typename BlockMmad::LayoutC;
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using BlockEpilogue = BlockEpilogue_;
using ElementScale = typename BlockEpilogue::ElementRawScale;
using LayoutScale = typename BlockEpilogue::LayoutScale;
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
using ElementD = typename BlockEpilogue::ElementD;
using LayoutD = typename BlockEpilogue::LayoutD;
using EpilogueParams = typename BlockEpilogue::Params;
using BlockScheduler = BlockScheduler_;
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
using ElementGroupList = ElementGroupList_;
/// Parameters structure
struct Params {
// Data members
GemmCoord problemShape;
uint32_t problemCount;
__gm__ ElementGroupList_ *ptrGroupList;
__gm__ ElementA *ptrA;
LayoutA layoutA;
__gm__ ElementB *ptrB;
LayoutB layoutB;
__gm__ ElementScale *ptrScale;
LayoutScale layoutScale;
__gm__ ElementPerTokenScale *ptrPerTokenScale;
LayoutPerTokenScale layoutPerTokenScale;
__gm__ ElementD *ptrD;
LayoutD layoutD;
GM_ADDR ptrWorkspace;
void *combiner;
// Methods
CATLASS_DEVICE
Params() {}
CATLASS_DEVICE
Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_,
GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_,
LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_,
void *combiner_)
: problemShape(problemShape_),
problemCount(problemCount_),
ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)),
ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)),
layoutA(layoutA_),
ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)),
layoutB(layoutB_),
ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)),
layoutScale(layoutScale_),
ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)),
layoutD(layoutD_),
ptrWorkspace(ptrWorkspace_),
combiner(combiner_)
{}
};
// Methods
CATLASS_DEVICE
GroupedMatmulSliceMMultiStageWorkspace()
{
Arch::FlagID flagId = 0;
for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) {
flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++);
flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++);
aicWaitFuncList[stageId] = {this, stageId};
aicSetFuncList[stageId] = {this, stageId};
}
}
template <int32_t CORE_TYPE = g_coreType>
CATLASS_DEVICE void operator()(Params const &params);
template <>
CATLASS_DEVICE void operator()<AscendC::AIC>(Params const &params)
{
BlockScheduler blockScheduler;
BlockMmad blockMmad(resource);
// Represent the full gm
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer(params.ptrA);
AscendC::GlobalTensor<ElementB> gmB;
AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB));
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr<int32_t>(0)));
}
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
uint32_t coreIdx = AscendC::GetBlockIdx();
uint32_t coreNum = AscendC::GetBlockNum();
int64_t gmGroupOffsetA = 0;
int64_t gmGroupOffsetB = 0;
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
uint32_t stageId = 0;
uint32_t stageUsed = 0;
uint32_t startCoreIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(
gmBlistTensorDesc.GetDataPtr<int32_t>(groupIdx)));
}
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
LayoutB layoutB = params.layoutB;
blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
uint32_t coreLoops = blockScheduler.GetCoreLoops();
// Determine the starting loopIdx of the current core under the current
// groupIdx
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
// Loop through the matmul of each groupIdx
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
// Compute block location
GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
Callback callbackBeforeFixpipe{};
if (stageUsed == WORKSPACE_STAGES) {
callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]);
} else {
++stageUsed;
}
Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]);
// Compute initial location in logical coordinates
MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
int64_t gmOffsetB = layoutB.GetOffset(offsetB);
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
// Compute block-scoped matrix multiply-add
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe);
} else {
callbackBeforeFixpipe();
blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
gmC[gmOffsetC], layoutC, actualBlockShape);
callbackAfterFixpipe();
}
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
}
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
blockMmad.SynchronizeBlock();
}
while (stageUsed > 0) {
uint32_t aivComputeStageId =
(stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed);
Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]);
--stageUsed;
}
}
template <>
CATLASS_DEVICE void operator()<AscendC::AIV>(Params const &params)
{
auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> *)params.combiner;
{
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
if (get_subblockid() == 0) {
AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID);
}
}
BlockScheduler blockScheduler;
BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo());
uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
uint32_t coreNum = AscendC::GetBlockNum();
int64_t gmGroupOffsetScale = 0;
int64_t gmGroupOffsetPerTokenScale = 0;
int64_t gmGroupOffsetD = 0;
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
AscendC::GlobalTensor<ElementC> gmC;
gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
uint32_t stageId = 0;
uint32_t startCoreIdx = 0;
AscendC::ListTensorDesc gmScaleListTensor;
gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale));
__gm__ ElementScale* gmScalePtr;
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr<int32_t>(0));
}
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
LayoutScale layoutScale = params.layoutScale;
LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
EpilogueParams epilogueParams;
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(
gmScaleListTensor.GetDataPtr<int32_t>(groupIdx));
epilogueParams = EpilogueParams {
gmScalePtr, layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
params.ptrD + gmGroupOffsetD, layoutD};
} else {
epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale,
layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
params.ptrD + gmGroupOffsetD,
layoutD};
}
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
blockEpilogue.UpdateParams(epilogueParams);
uint32_t coreLoops = blockScheduler.GetCoreLoops();
GemmCoord blockShapeMNK = L1TileShape::ToCoord();
uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx);
GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK);
MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
auto gmBlockC = gmC[gmOffsetC];
auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN());
Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]);
blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC,
layoutBlockC);
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]);
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetScale += inGroupProblemShape.n();
}
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n();
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
}
icache_preload(4);
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
if (get_subblockid() == 0) {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->AllToAllSend();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
} else {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->ReducePermute();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
}
} else {
resource.pipe.Init();
combiner->TPipeSet(&resource.pipe);
combiner->Process();
combiner->TPipeSet(nullptr);
resource.pipe.Destroy();
}
}
private:
friend struct AicWaitFunc;
friend struct AicSetFunc;
struct AicWaitFunc {
using MatmulKernel =
GroupedMatmulSliceMMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicWaitFunc() = default;
CATLASS_DEVICE
void operator()() const
{
Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]);
}
MatmulKernel *ptr{nullptr};
uint32_t stageId;
};
struct AicSetFunc {
using MatmulKernel =
GroupedMatmulSliceMMultiStageWorkspace<TemplateMC2TypeFunc, BlockMmad, BlockEpilogue,
BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicSetFunc() = default;
CATLASS_DEVICE
void operator()() const
{
Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]);
}
MatmulKernel *ptr{nullptr};
uint32_t stageId;
};
Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES];
Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES];
AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES];
AicSetFunc aicSetFuncList[WORKSPACE_STAGES];
Arch::Resource<ArchTag> resource;
};
} // namespace Catlass::Gemm::Kernel
#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP

View File

@@ -22,51 +22,6 @@
#include "../../../dispatch_gmm_combine_decode_base.h"
constexpr uint32_t STATE_OFFSET = 512;
constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024;
constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024;
constexpr uint32_t UB_ALIGN = 32;
constexpr uint32_t TOKEN_EXTRA_SPACE = 512;
constexpr uint32_t INT32_COUNT_PER_BLOCK = 8;
constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512;
constexpr int64_t LOOP_TMP_SIZE = 4096;
constexpr int32_t SUB_AIV_NUM = 2;
constexpr int32_t ODD_EVEN_BASE = 2;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t GATHER_SECOND_NUM = 2;
constexpr uint32_t MAX_QUANT_ROW_ONCE = 8;
constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant
#define OPT_RANK_OFFSET 512
#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN)
#define CEIL(x, y) (((x) + (y - 1)) / (y))
#define UB_BLOCK_SIZE (32)
#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \
(((epRankId == rankId) \
? ((GM_ADDR)(winContext_->localWindowsExp)) \
: ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \
dataState * WIN_STATE_OFFSET)
#define GET_WIND_ADDR_BY_RANK_ID(rankId) \
(((epRankId == rankId) \
? ((GM_ADDR)(winContext_->localWindowsIn)) \
: ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \
winDataSizeOffset + rankId * OPT_RANK_OFFSET)
#define TOKEN_FLAG_1 (0x55555555)
#define TOKEN_FLAG_2 (0x33333333)
#define V_TO_C_FLAG_1 (0x03030303)
#define V_TO_C_FLAG_2 (0x05050505)
#define CV_FLAG_INDEX 0
#define GROUP_ID_INDEX 1
#define PRE_COUNT_INDEX 2
#define SELF_COUNT_INDEX 3
#define TOTAL_COUNT_INDEX 4
#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX
#define GROUP_INFO_SIZE 32
namespace Catlass::Gemm::Kernel {
template <class ArchTag>
@@ -306,54 +261,6 @@ private:
Epilogue::Tile::CopyUb2Gm<ArchTag, OutputType> copyUbToGmOutput;
};
__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx)
{
// flag++, like set flag
AscendC::PipeBarrier<PIPE_ALL>();
AscendC::GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
global);
__asm__ __volatile__("");
uint8_t value = global.GetValue(0);
global.SetValue(0, value + 1);
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
global);
__asm__ __volatile__("");
AscendC::PipeBarrier<PIPE_ALL>();
}
__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target)
{
// check flag, like wait flag
AscendC::PipeBarrier<PIPE_ALL>();
AscendC::GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
while (true) {
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(global);
__asm__ __volatile__("");
uint8_t value = global.GetValue(0);
if (value >= target) {
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(global);
__asm__ __volatile__("");
break;
}
}
AscendC::PipeBarrier<PIPE_ALL>();
}
__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
{
row = QUANT_SPACE_FACTOR / column;
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
}
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace

View File

@@ -9,7 +9,9 @@
*/
#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H
#define CAM_MOE_DISTRIBUTE_COMBINE_H
#ifndef OPT_RANK_OFFSET
#define OPT_RANK_OFFSET 512
#endif
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"

View File

@@ -10,7 +10,9 @@
#ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H
#define CAM_MOE_DISTRIBUTE_DISPATCH_H
#ifndef OPT_RANK_OFFSET
#define OPT_RANK_OFFSET 512
#endif
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"

View File

@@ -12,10 +12,107 @@
#include "../common/moe_distribute_base.h"
#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename WType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, WType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
#define TemplateDispatchTypeClass \
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
bool IsNeedAllgater, uint32_t EXEC_FLAG
#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
constexpr uint32_t STATE_OFFSET = 512;
constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024;
constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024;
constexpr uint32_t UB_ALIGN = 32;
constexpr uint32_t TOKEN_EXTRA_SPACE = 512;
constexpr uint32_t INT32_COUNT_PER_BLOCK = 8;
constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512;
constexpr int64_t LOOP_TMP_SIZE = 4096;
constexpr int32_t SUB_AIV_NUM = 2;
constexpr int32_t ODD_EVEN_BASE = 2;
constexpr int32_t BUFFER_NUM = 2;
constexpr int32_t GATHER_SECOND_NUM = 2;
constexpr uint32_t MAX_QUANT_ROW_ONCE = 8;
constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant
#ifndef OPT_RANK_OFFSET
#define OPT_RANK_OFFSET 512
#endif
#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN)
#define CEIL(x, y) (((x) + (y - 1)) / (y))
#define UB_BLOCK_SIZE (32)
#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \
(((epRankId == rankId) \
? ((GM_ADDR)(winContext_->localWindowsExp)) \
: ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \
dataState * WIN_STATE_OFFSET)
#define GET_WIND_ADDR_BY_RANK_ID(rankId) \
(((epRankId == rankId) \
? ((GM_ADDR)(winContext_->localWindowsIn)) \
: ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \
winDataSizeOffset + rankId * OPT_RANK_OFFSET)
#define TOKEN_FLAG_1 (0x55555555)
#define TOKEN_FLAG_2 (0x33333333)
#define V_TO_C_FLAG_1 (0x03030303)
#define V_TO_C_FLAG_2 (0x05050505)
#define CV_FLAG_INDEX 0
#define GROUP_ID_INDEX 1
#define PRE_COUNT_INDEX 2
#define SELF_COUNT_INDEX 3
#define TOTAL_COUNT_INDEX 4
#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX
#define GROUP_INFO_SIZE 32
__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx)
{
// flag++, like set flag
AscendC::PipeBarrier<PIPE_ALL>();
AscendC::GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
global);
__asm__ __volatile__("");
uint8_t value = global.GetValue(0);
global.SetValue(0, value + 1);
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
global);
__asm__ __volatile__("");
AscendC::PipeBarrier<PIPE_ALL>();
}
__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target)
{
// check flag, like wait flag
AscendC::PipeBarrier<PIPE_ALL>();
AscendC::GlobalTensor<uint8_t> global;
global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
while (true) {
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(global);
__asm__ __volatile__("");
uint8_t value = global.GetValue(0);
if (value >= target) {
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(global);
__asm__ __volatile__("");
break;
}
}
AscendC::PipeBarrier<PIPE_ALL>();
}
__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
{
row = QUANT_SPACE_FACTOR / column;
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
}
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H

View File

@@ -0,0 +1,457 @@
/*
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H
#define DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H
#include "lib/matmul_intf.h"
#include <kernel_operator.h>
#include "catlass/catlass.hpp"
#include "catlass/arch/arch.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/epilogue/tile/tile_broadcast_mul.hpp"
#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp"
#include "catlass/epilogue/tile/tile_swizzle.hpp"
#include "catlass/gemm/block/block_swizzle.hpp"
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h"
#include "catlass/gemm/gemm_type.hpp"
#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h"
#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h"
#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h"
#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h"
#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h"
#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h"
#include "dispatch_gmm_combine_decode_tiling.h"
#include "dispatch_gmm_combine_decode_base.h"
using namespace Catlass;
namespace DispatchGmmCombineDecodeBf16Fp16Impl {
using MmadAtlasA2Custom =
Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
CUSTOM_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG,
CUSTOM_ENABLE_SHUFFLE_K>;
using Gmm1L1TileShape = GemmShape<FP16_BF16_L1M, FP16_BF16_L1N, GMM1_L1K>;
using Gmm1L0TileShape = GemmShape<Gmm1L1TileShape::M, Gmm1L1TileShape::N, GMM1_L0K>;
using Gmm1EpilogueTileShape = MatrixShape<GMM1_EPIM, Gmm1L1TileShape::N>;
using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM1_SWIZZLE_OFFSET, GMM1_SWIZZLE_DIRECTION>;
using Gmm2L1TileShape = GemmShape<FP16_BF16_GMM2_L1M, FP16_BF16_GMM2_L1N, GMM2_L1K>;
using Gmm2L0TileShape = GemmShape<Gmm2L1TileShape::M, Gmm2L1TileShape::N, GMM2_L0K>;
using Gmm2EpilogueTileShape = MatrixShape<GMM2_EPIM, Gmm2L1TileShape::N>;
using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<GMM2_SWIZZLE_OFFSET, GMM2_SWIZZLE_DIRECTION>;
using Gmm2DispatchPolicy =
Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA<CUSTOM_PRELOAD_STAGES, GMM2_L1A_STAGES, GMM2_L1B_STAGES,
GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB,
typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB,
GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = DispatchPolicy_;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
using BType = Gemm::GemmType<WType, LayoutB>;
using CType = Gemm::GemmType<float, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Swiglu<ubStages, 0>;
using ScaleType = Gemm::GemmType<W1ScaleType, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<float, layout::RowMajor>;
using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using EpilogueTileShape = EpilogueTileShape_;
using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
using TileBroadcastOneBlk =
Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
using TileOneBlkColumnBroadcastMul =
Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;
using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
DType, TileRowBroadcastMul, TileBroadcastOneBlk,
TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;
using BlockScheduler = BlockScheduler_;
// kernel level
using ElementGroupList = int64_t;
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0,
Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspace<
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch<
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0) {
typename GemmKernel::Params params{problemShape,
groupCount,
gmGroupList,
gmA,
layoutA,
gmB,
layoutB,
gmScale,
layoutScale,
gmPerTokenScale,
layoutPerTokenScale,
gmD,
layoutD,
gmDequantScale,
layoutDequantScale,
gmWorkspace,
gmX,
debugGm,
gmexpertIds,
gmExpandIdx,
gmEpSendCount,
xActiveMask,
gmResvered,
gmExpertTokenNums,
epRankSize,
epRankId,
moeExpertNum,
moeExpertNumPerRank,
sharedExpertNum,
sharedExpertRankNum,
quantMode,
globalBs,
bs,
topK,
tokenLen};
// call a kernel
GemmKernel gemm;
gemm(params);
} else {
typename GemmKernel::Params params{problemShape,
groupCount,
gmGroupList,
gmA,
layoutA,
gmB,
layoutB,
gmScale,
layoutScale,
gmPerTokenScale,
layoutPerTokenScale,
gmD,
layoutD,
gmDequantScale,
layoutDequantScale,
gmWorkspace};
// call a kernel
GemmKernel gemm;
gemm(params);
}
}
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB,
typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB,
GM_ADDR gmScale,
layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmWorkspace, void *combiner)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = DispatchPolicy_;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using AType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
using BType = Gemm::GemmType<WType, LayoutB>;
using CType = Gemm::GemmType<float, layout::RowMajor>;
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Combine<ubStages, EXEC_FLAG>;
using ScaleType = Gemm::GemmType<W2ScaleType, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
using RowBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using BroadcastOneBlkType = Gemm::GemmType<float, layout::RowMajor>;
using OneBlkColumnBroadcastMulType = Gemm::GemmType<float, layout::RowMajor>;
using EpilogueTileShape = EpilogueTileShape_;
using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul<ArchTag, RowBroadcastMulType, EpilogueTileShape>;
using TileBroadcastOneBlk =
Epilogue::Tile::TileBroadcastOneBlk<ArchTag, BroadcastOneBlkType, EpilogueTileShape::ROW>;
using TileOneBlkColumnBroadcastMul =
Epilogue::Tile::TileOneBlkColumnBroadcastMul<ArchTag, OneBlkColumnBroadcastMulType, EpilogueTileShape>;
using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;
using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
DType, TileRowBroadcastMul, TileBroadcastOneBlk,
TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;
using BlockScheduler = BlockScheduler_;
// kernel level
using ElementGroupList = int64_t;
using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMMultiStageWorkspace<
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
typename GemmKernel::Params params{
problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale,
layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};
// call a kernel
GemmKernel gemm;
gemm(params);
}
template <TemplateMC2TypeClass>
class DispatchGmmCombineDecodeBf16Fp16
{
public:
__aicore__ inline DispatchGmmCombineDecodeBf16Fp16(){};
__aicore__ inline void Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR expertTokenNums,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process();
private:
GM_ADDR gmX_;
GM_ADDR gmexpertIds_;
GM_ADDR gmPermuteWeight1_;
GM_ADDR gmPermuteScale1_;
GM_ADDR gmWeight2_;
GM_ADDR gmScale2_;
GM_ADDR gmOutput_;
GM_ADDR gmExpertTokenNums_;
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
GM_ADDR xActiveMask_;
uint32_t maxTokenNum_{0};
uint32_t gmm1OutputDim_{0};
uint32_t tokenHiddenSize_{0};
uint32_t groupCount_{0};
uint32_t gmm2OutputDim_{0};
uint32_t gmm2InputDim_{0};
uint32_t globalRankId_{0};
uint32_t winSizePerRank_{0};
uint32_t blockDim_{0};
uint32_t epRankSize_{0};
uint32_t epRankId_{0};
uint32_t moeExpertNum_{0};
uint32_t moeExpertNumPerRank_{0};
uint32_t sharedExpertNum_{0};
uint32_t sharedExpertRankNum_{0};
uint32_t quantMode_{0};
uint32_t globalBs_{0};
uint32_t bs_{0};
uint32_t maxBs_{0};
uint32_t topK_{0};
AscendC::TPipe *tpipe_{nullptr};
__gm__ HcclOpResParam *winContext_{nullptr};
const DispatchGmmCombineDecodeTilingData *tilingData_;
};
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16<TemplateMC2TypeFunc>::Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR expertTokenNums,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
tpipe_ = pipe;
blockDim_ = AscendC::GetBlockNum();
winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
gmSmoothScales_ = expert_smooth_scales; // not used now
gmX_ = x; // input token
gmexpertIds_ = expert_ids;
gmPermuteWeight1_ = gmm1_permuted_weight;
gmPermuteScale1_ = nullptr;
gmWeight2_ = gmm2_weight;
gmScale2_ = nullptr;
gmOutput_ = output;
gmExpertTokenNums_ = expertTokenNums;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
xActiveMask_ = x_active_mask;
tilingData_ = tilingData;
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum;
sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode;
globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
maxBs_ = globalBs_ / epRankSize_;
bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
if (isShareExpert) {
maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
} else {
maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
}
gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
gmm2OutputDim_ = tokenHiddenSize_;
gmm2InputDim_ = gmm1OutputDim_ / 2;
}
template<uint32_t EXEC_FLAG, typename WType>
__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) {
if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) {
MatrixCoord mc{k, n};
return layout::RowMajor::template MakeLayoutInUb<WType>(mc);
} else {
return layout::zN::template MakeLayout<WType>(k, n);
}
}
template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16<TemplateMC2TypeFunc>::Process()
{
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_};
GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};
layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
auto layoutWeight1 = CreateWeightLayout<EXEC_FLAG, WType>(tokenHiddenSize_, gmm1OutputDim_);
layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
layout::VectorLayout layoutX1Scale{maxTokenNum_};
layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
auto layoutWeight2 = CreateWeightLayout<EXEC_FLAG, WType>(gmm2InputDim_, gmm2OutputDim_);
layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
layout::VectorLayout layoutX2Scale{maxTokenNum_};
layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};
size_t workspaceOffset = 0;
constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(ExpandXType);
int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
GM_ADDR gmX1 = workspaceGM_ + workspaceOffset;
GM_ADDR gmX2 = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxTokenSize);
GM_ADDR gmX1Scale = nullptr;
GM_ADDR gmX2Scale = nullptr;
GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;
GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(blockDim_) * (FP16_BF16_L1M * FP16_BF16_L1N) *
WORKSPACE_STAGES * sizeof(float));
int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float);
int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize;
GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(maxSwigluGmm2Size);
GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(groupCount_) * sizeof(int64_t));
GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(bs_) * topK_ * sizeof(int32_t));
GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(static_cast<size_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
if constexpr (g_coreType == AscendC::AIV) {
AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, ExpandXType, false, false, false, false, EXEC_FLAG>
dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
icache_preload(8);
}
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) {
Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
} else {
Arch::CrossCoreWaitFlag(gmm1AivFinished);
}
}
GmmDeqSwigluQuant<TemplateMC2TypeFunc, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};
if constexpr (g_coreType == AscendC::AIV) {
Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
} else {
Arch::CrossCoreWaitFlag(gmm1AivFinished);
}
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
if (g_coreType == AscendC::AIV) {
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
workspaceGM_, nullptr, tilingData_);
}
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
Gmm2DispatchPolicy>(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
layoutOutput, gmWorkspace, &combiner);
}
} // namespace DispatchGmmCombineDecodeBf16Fp16Impl
#endif // DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H

View File

@@ -31,6 +31,8 @@ struct DispatchGmmCombineDecodeInfo {
uint64_t totalWinSize;
uint64_t gmm1HLen;
bool isTensorList;
bool isBf16Fp16W;
bool isNDFormat;
};
struct DispatchGmmCombineDecodeTilingData {
@@ -48,6 +50,8 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1;
constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true;
constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true;
constexpr uint32_t FP16_BF16_L1M = 128;
constexpr uint32_t FP16_BF16_L1N = 128;
constexpr uint32_t GMM1_L1M = 256;
constexpr uint32_t GMM1_L1N = 128;
constexpr uint32_t GMM1_L1K = 512;
@@ -56,6 +60,8 @@ constexpr uint32_t GMM1_EPIM = 64;
constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3;
constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0;
constexpr uint32_t FP16_BF16_GMM2_L1M = 64;
constexpr uint32_t FP16_BF16_GMM2_L1N = 128;
constexpr uint32_t GMM2_L1A_STAGES = 4;
constexpr uint32_t GMM2_L1B_STAGES = 2;
constexpr uint32_t GMM2_L0A_STAGES = 4;
@@ -73,5 +79,6 @@ constexpr uint32_t WORKSPACE_STAGES = 4;
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1);
constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
constexpr uint32_t EXEC_FLAG_ND_FORMAT = (1U << 3);
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H