cherry-pick https://github.com/vllm-project/vllm-ascend/pull/8539 <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> Due to end-to-end testing, three optimization points for the decode scenario have been reverted in the dispatch_ffn_combine kernel. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: l00893928 <liuquanlu@huawei.com> Co-authored-by: l00893928 <liuquanlu@huawei.com>
283 lines
10 KiB
C++
283 lines
10 KiB
C++
/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file dispatch_ffn_combine.h
 * \brief Fused dispatch + FFN (grouped GEMM with swiglu-quant epilogue) + combine kernel declaration.
 */
|
|
|
|
#ifndef DISPATCH_FFN_COMBINE_H
|
|
#define DISPATCH_FFN_COMBINE_H
|
|
|
|
using namespace AscendC;
|
|
|
|
#include "kernel_operator.h"
|
|
|
|
#include "utils/moe_distribute_base.h"
|
|
|
|
#include "dispatch_ffn_combine_tiling.h"
|
|
|
|
#include "catlass/catlass.hpp"
|
|
#include "catlass/arch/arch.hpp"
|
|
#include "catlass/epilogue/dispatch_policy.hpp"
|
|
#include "catlass/epilogue/block/block_epilogue.hpp"
|
|
#include "catlass/epilogue/tile/tile_copy.hpp"
|
|
#include "catlass/epilogue/tile/tile_elemwise_add.hpp"
|
|
#include "catlass/epilogue/tile/tile_elemwise_muls.hpp"
|
|
#include "catlass/gemm/block/block_mmad.hpp"
|
|
#include "catlass/gemm/block/block_swizzle.hpp"
|
|
#include "catlass/gemm/dispatch_policy.hpp"
|
|
#include "catlass/gemm/kernel/matmul_epilogue.hpp"
|
|
#include "catlass/gemm/gemm_type.hpp"
|
|
#include "catlass/layout/layout.hpp"
|
|
|
|
#include "utils/select_helper.hpp"
|
|
#include "utils/const_args.hpp"
|
|
#include "dispatch_ffn_combine_kernel.hpp"
|
|
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
|
|
|
|
using namespace Catlass;
|
|
|
|
namespace DispatchFFNCombineImpl {
|
|
#define TemplateMMA2AClass typename AType_, typename BType_, typename CType_, bool TB_, bool Nz_
|
|
#define TemplateMMA2ACFunc AType_, BType_, CType_, TB_, Nz_
|
|
|
|
using namespace AscendC;
|
|
// Fused MoE dispatch -> FFN -> combine device kernel wrapper.
// Template parameters (via TemplateMMA2AClass): AType_/BType_/CType_ element types,
// TB_ = weight transposed (column-major) flag, Nz_ = weight stored in zN (NZ) layout.
template <TemplateMMA2AClass>
class DispatchFFNCombine {
public:
    __aicore__ inline DispatchFFNCombine() {};
    // Binds global-memory addresses and unpacks the tiling blob; no device compute here.
    __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
        GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
    // Instantiates the CATLASS kernel types from the stored tiling and runs the fused kernel.
    __aicore__ inline void Process();

private:
    // Global-memory operands captured in Init().
    GM_ADDR xGM_;             // activations
    GM_ADDR weight1GM_;       // FFN first projection weights
    GM_ADDR weight2GM_;       // FFN second projection weights
    GM_ADDR expertIdGM_;      // per-token expert ids
    GM_ADDR scale1GM_;        // dequant scales for GEMM1
    GM_ADDR scale2GM_;        // dequant scales for GEMM2
    GM_ADDR probs_;           // routing probabilities used by combine
    GM_ADDR outGM_;           // final output
    GM_ADDR gmExpertTokenNums_;
    GM_ADDR workspaceGM_;

    // Optional inputs of the fused moe_init_routing_quant_v2 stage; never set in this file.
    GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
    GM_ADDR moeInitRoutingQuantV2Offset = nullptr;
    GM_ADDR expertTokensBeforeCapacity = nullptr;

    TBuf<AscendC::TPosition::VECCALC> uBuf_;

    // Communication topology (read from the HCCL window context in Init()).
    int32_t rank;
    int32_t rankSize;
    int32_t aivNum;           // number of AIV (vector) cores from tiling

    // GEMM base-block tiling parameters (cocTiling).
    int32_t m0;
    int32_t k0;
    int32_t n0;
    int32_t swizzlOffset;
    int32_t swizzlDirect;
    int32_t ubMoveNum;        // UB copy granularity forwarded to the kernel params
    int32_t pValue;

    // Communication split parameters from tiling (not consumed in this header's Process()).
    int32_t commNpuSplit;
    int32_t commDataSplit;
    int32_t lenPerLoop;

    // Problem shape and MoE configuration.
    int32_t m;
    int32_t k;
    int32_t n;
    int32_t topK;
    int32_t expertPerRank;
    int32_t maxOutputSize;
    int32_t EP;               // expert-parallel world size
    int32_t listLen;

    // Tiling for the fused routing/quant preprocessing stage.
    optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
    uint64_t initRoutingQuantTilingKey;

    // Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
};
|
|
|
|
|
|
template <TemplateMMA2AClass>
|
|
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
|
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
|
{
|
|
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
|
auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
|
|
GET_TILING_DATA(tilingData, tilingGM);
|
|
|
|
xGM_ = xGM;
|
|
weight1GM_ = weight1GM;
|
|
weight2GM_ = weight2GM;
|
|
expertIdGM_ = expertIdGM;
|
|
scale1GM_ = scale1GM;
|
|
scale2GM_ = scale2GM;
|
|
probs_ = probs;
|
|
|
|
outGM_ = outGM;
|
|
gmExpertTokenNums_ = expertTokenNums;
|
|
|
|
workspaceGM_ = workspaceGM;
|
|
|
|
aivNum = tilingData.dispatchFFNCombineInfo.aivNum;
|
|
|
|
m = tilingData.dispatchFFNCombineInfo.M;
|
|
k = tilingData.dispatchFFNCombineInfo.K;
|
|
n = tilingData.dispatchFFNCombineInfo.N;
|
|
EP = tilingData.dispatchFFNCombineInfo.worldSize;
|
|
topK = tilingData.dispatchFFNCombineInfo.topK;
|
|
expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank;
|
|
maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize;
|
|
listLen = tilingData.dispatchFFNCombineInfo.listLen;
|
|
|
|
m0 = tilingData.cocTiling.m0;
|
|
k0 = tilingData.cocTiling.k0;
|
|
n0 = tilingData.cocTiling.n0;
|
|
swizzlDirect = tilingData.cocTiling.swizzleDirect;
|
|
swizzlOffset = tilingData.cocTiling.swizzleOffset;
|
|
ubMoveNum = tilingData.cocTiling.ubMoveNum;
|
|
pValue = tilingData.cocTiling.pValue;
|
|
commNpuSplit = tilingData.cocTiling.commNpuSplit;
|
|
commDataSplit = tilingData.cocTiling.commDataSplit;
|
|
lenPerLoop = tilingData.cocTiling.lenPerLoop;
|
|
moeInitRoutingQuantV2TilingData = tilingData.cocTiling.moeInitRoutingQuantV2TilingData;
|
|
initRoutingQuantTilingKey = tilingData.cocTiling.initRoutingQuantTilingKey;
|
|
|
|
auto contextGM0 = AscendC::GetHcclContext<HCCL_GROUP_ID_0>();
|
|
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
|
WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0;
|
|
|
|
rank = WinContext_->localUsrRankId;
|
|
rankSize = WinContext_->rankSize;
|
|
}
|
|
|
|
template <TemplateMMA2AClass>
|
|
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
|
|
{
|
|
// Define ArchTag
|
|
using ArchTag = Arch::AtlasA2;
|
|
constexpr bool enableUnitFlag = false;
|
|
constexpr bool enableShuffleK = true;
|
|
|
|
uint32_t k2 = n/2;
|
|
uint32_t n2 = k;
|
|
|
|
int64_t activeNum = 0;
|
|
int64_t expertCapacity = 0;
|
|
int64_t expertNum = expertPerRank * EP;
|
|
int64_t dropPadMode = 0;
|
|
int64_t expertTokensCountOrCumsumFlag = 2;
|
|
bool expertTokensBeforeCapacityFlag = false;
|
|
int64_t quantMode = 1;
|
|
|
|
using LayoutA = layout::RowMajor;
|
|
using LayoutB = typename std::conditional<
|
|
Nz_,
|
|
layout::zN,
|
|
typename std::conditional<TB_, layout::ColumnMajor, layout::RowMajor>::type
|
|
>::type;
|
|
|
|
LayoutB layoutB1 = LayoutBInitializer<LayoutB, BType_>::create(k, n);
|
|
LayoutB layoutB2 = LayoutBInitializer<LayoutB, BType_>::create(k2, n2);
|
|
using LayoutC = layout::RowMajor;
|
|
using L1TileShape = GemmShape<128, 256, 512>; // M, N, K
|
|
|
|
constexpr uint32_t workspaceStages = 2;
|
|
constexpr uint32_t preloadStages = 1;
|
|
constexpr uint32_t l1Stages = 2;
|
|
constexpr uint32_t l0AStages = 2;
|
|
constexpr uint32_t l0BStages = 2;
|
|
constexpr uint32_t l0CStages = 1;
|
|
|
|
using DispatchPolicy = Gemm::MmadAtlasA2PreloadAsyncFixpipe<
|
|
preloadStages,
|
|
l1Stages, l0AStages, l0BStages, l0CStages,
|
|
enableUnitFlag, enableShuffleK
|
|
>;
|
|
|
|
using L0TileShape = GemmShape<128, 256, 128>;
|
|
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
|
|
using BType = Gemm::GemmType<int8_t, LayoutB>;
|
|
using CType = Gemm::GemmType<float16_t, layout::RowMajor>;
|
|
using D1Type = Gemm::GemmType<int8_t, layout::RowMajor>;
|
|
|
|
using D2Type = typename std::conditional<
|
|
std::is_same_v<CType_, bfloat16_t>,
|
|
Gemm::GemmType<bfloat16_t, layout::RowMajor>,
|
|
Gemm::GemmType<CType_, layout::RowMajor>
|
|
>::type;
|
|
|
|
using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
|
|
constexpr uint32_t ubStages = 2;
|
|
|
|
using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant<ubStages>;
|
|
|
|
using ScaleType = Gemm::GemmType<uint64_t, layout::VectorLayout>;
|
|
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
|
using ElementMulType = Gemm::GemmType<float, layout::RowMajor>;
|
|
using TileElemWiseMuls = Epilogue::Tile::TileElemWiseMuls<ArchTag, ElementMulType, 0>;
|
|
|
|
using TileCopy1 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D1Type>;
|
|
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
|
|
D1Type, TileElemWiseMuls, TileCopy1>;
|
|
|
|
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
|
|
|
|
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
|
|
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
|
|
D2Type, TileCopy2>;
|
|
|
|
using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<9, 1>;
|
|
using ElementGroupList = int64_t;
|
|
using MatmulKernel = Gemm::Kernel::DispatchFFNCombineKernel<BlockMmad,
|
|
BlockScheduler, ElementGroupList, BlockEpilogue1, BlockEpilogue2>;
|
|
|
|
LayoutA layoutA1{static_cast<uint32_t>(m), static_cast<uint32_t>(k)};
|
|
LayoutA layoutA2{static_cast<uint32_t>(m), static_cast<uint32_t>(k2)};
|
|
layout::VectorLayout layoutScale1{static_cast<uint32_t>(n)};
|
|
layout::VectorLayout layoutScale2{static_cast<uint32_t>(n2)};
|
|
layout::RowMajor layoutD1{static_cast<uint32_t>(maxOutputSize), static_cast<uint32_t>(k2)};
|
|
layout::RowMajor layoutD2{static_cast<uint32_t>(m*topK), static_cast<uint32_t>(n2)};
|
|
// Prepare params
|
|
|
|
GemmCoord problemShape{static_cast<uint32_t>(m), static_cast<uint32_t>(n), static_cast<uint32_t>(k)};
|
|
|
|
uint32_t epilogueCoreNum = aivNum;
|
|
uint32_t epilogueGranularity = expertPerRank - 3;
|
|
if (expertPerRank <= 4) {
|
|
epilogueGranularity = expertPerRank - 1;
|
|
}
|
|
typename MatmulKernel::Params params{
|
|
problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(listLen), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
|
|
static_cast<uint32_t>(rank), static_cast<uint32_t>(rankSize),
|
|
static_cast<uint32_t>(topK), initRoutingQuantTilingKey,
|
|
epilogueCoreNum, epilogueGranularity,
|
|
xGM_, layoutA1, layoutA2,
|
|
weight1GM_, layoutB1,
|
|
weight2GM_, layoutB2,
|
|
scale1GM_, layoutScale1,
|
|
scale2GM_, layoutScale2,
|
|
outGM_, layoutD1, layoutD2,
|
|
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
|
|
expertTokensBeforeCapacity, probs_,
|
|
workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
|
|
//Call kernel
|
|
MatmulKernel kernel(params);
|
|
kernel(params);
|
|
}
|
|
|
|
} // DispatchFFNCombineImpl
|
|
#endif // DISPATCH_FFN_COMBINE_H
|