[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)
#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor list and
expert token numbers.
This operator supports gmm1, gmm2, gmm1Scale and gmm2Scale in the format of a
list.
This operator supports counting how many tokens each local expert receives
via expertTokenNums.
- vLLM version: v0.13.0
- vLLM main:
7157596103
More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
@@ -17,7 +17,7 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
|
||||
GM_ADDR x_active_mask,
|
||||
// output
|
||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||
GM_ADDR output, GM_ADDR expertTokenNums,
|
||||
// system
|
||||
GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
@@ -25,10 +25,11 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
||||
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
|
||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
|
||||
GET_TILING_DATA(tiling_data, tiling);
|
||||
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
|
||||
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
|
||||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
|
||||
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
|
||||
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
|
||||
expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
|
||||
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
|
||||
op.Process();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
||||
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
|
||||
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
|
||||
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
|
||||
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
|
||||
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums,
|
||||
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
|
||||
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
|
||||
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
|
||||
@@ -138,7 +138,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
||||
gmEpSendCount,
|
||||
xActiveMask,
|
||||
gmResvered,
|
||||
gmOutputRecvCount,
|
||||
gmExpertTokenNums,
|
||||
epRankSize,
|
||||
epRankId,
|
||||
moeExpertNum,
|
||||
@@ -244,7 +244,7 @@ public:
|
||||
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
|
||||
// output
|
||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||
GM_ADDR output, GM_ADDR expertTokenNums,
|
||||
// system
|
||||
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
|
||||
__aicore__ inline void Process();
|
||||
@@ -257,7 +257,7 @@ private:
|
||||
GM_ADDR gmWeight2_;
|
||||
GM_ADDR gmScale2_;
|
||||
GM_ADDR gmOutput_;
|
||||
GM_ADDR gmOutputRecvCount_;
|
||||
GM_ADDR gmExpertTokenNums_;
|
||||
GM_ADDR workspaceGM_;
|
||||
GM_ADDR gmSmoothScales_;
|
||||
GM_ADDR gmexpertScales_;
|
||||
@@ -296,7 +296,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
|
||||
GM_ADDR x_active_mask,
|
||||
// output
|
||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||
GM_ADDR output, GM_ADDR expertTokenNums,
|
||||
// system
|
||||
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
||||
{
|
||||
@@ -312,7 +312,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
||||
gmWeight2_ = gmm2_weight;
|
||||
gmScale2_ = gmm2_weight_scale;
|
||||
gmOutput_ = output;
|
||||
gmOutputRecvCount_ = outputRecvCount;
|
||||
gmExpertTokenNums_ = expertTokenNums;
|
||||
workspaceGM_ = workspaceGM;
|
||||
gmexpertScales_ = expert_scales;
|
||||
xActiveMask_ = x_active_mask;
|
||||
@@ -396,7 +396,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
||||
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
|
||||
dispatcher;
|
||||
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
|
||||
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
|
||||
gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_);
|
||||
dispatcher.Process();
|
||||
tpipe.Destroy();
|
||||
icache_preload(8);
|
||||
@@ -416,7 +416,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
||||
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
|
||||
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
|
||||
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
|
||||
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
|
||||
gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
|
||||
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
Arch::CrossCoreFlag gmm1AivFinished{0};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
|
||||
#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
|
||||
|
||||
#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h"
|
||||
#include "../../raw_distributed/cam_moe_distribute_combine.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/cross_core_sync.hpp"
|
||||
@@ -122,7 +123,10 @@ public:
|
||||
AscendC::GlobalTensor<ElementA> gmA;
|
||||
gmA.SetGlobalBuffer(params.ptrA);
|
||||
AscendC::GlobalTensor<ElementB> gmB;
|
||||
gmB.SetGlobalBuffer(params.ptrB);
|
||||
AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB));
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr<int32_t>(0)));
|
||||
}
|
||||
AscendC::GlobalTensor<ElementGroupList> groupList;
|
||||
groupList.SetGlobalBuffer(params.ptrGroupList);
|
||||
|
||||
@@ -139,6 +143,10 @@ public:
|
||||
uint32_t stageUsed = 0;
|
||||
uint32_t startCoreIdx = 0;
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
|
||||
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(
|
||||
gmBlistTensorDesc.GetDataPtr<int32_t>(groupIdx)));
|
||||
}
|
||||
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
|
||||
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
|
||||
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
|
||||
@@ -189,7 +197,9 @@ public:
|
||||
}
|
||||
|
||||
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
|
||||
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
|
||||
}
|
||||
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
|
||||
}
|
||||
|
||||
@@ -232,6 +242,12 @@ public:
|
||||
|
||||
uint32_t stageId = 0;
|
||||
uint32_t startCoreIdx = 0;
|
||||
AscendC::ListTensorDesc gmScaleListTensor;
|
||||
gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale));
|
||||
__gm__ ElementScale* gmScalePtr;
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr<int32_t>(0));
|
||||
}
|
||||
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
|
||||
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
|
||||
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
|
||||
@@ -241,12 +257,22 @@ public:
|
||||
LayoutPerTokenScale layoutPerTokenScale =
|
||||
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
|
||||
LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
|
||||
EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
|
||||
EpilogueParams epilogueParams;
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
|
||||
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(
|
||||
gmScaleListTensor.GetDataPtr<int32_t>(groupIdx));
|
||||
epilogueParams = EpilogueParams {
|
||||
gmScalePtr, layoutScale,
|
||||
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
|
||||
params.ptrD + gmGroupOffsetD, layoutD};
|
||||
} else {
|
||||
epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale,
|
||||
layoutScale,
|
||||
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
|
||||
layoutPerTokenScale,
|
||||
params.ptrD + gmGroupOffsetD,
|
||||
layoutD};
|
||||
}
|
||||
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
|
||||
blockEpilogue.UpdateParams(epilogueParams);
|
||||
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
||||
@@ -270,7 +296,9 @@ public:
|
||||
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
|
||||
}
|
||||
|
||||
gmGroupOffsetScale += inGroupProblemShape.n();
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmGroupOffsetScale += inGroupProblemShape.n();
|
||||
}
|
||||
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
|
||||
gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n();
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h"
|
||||
#include "catlass/catlass.hpp"
|
||||
#include "catlass/arch/cross_core_sync.hpp"
|
||||
#include "catlass/arch/resource.hpp"
|
||||
@@ -416,7 +417,7 @@ public:
|
||||
GM_ADDR gmExpandIdx;
|
||||
GM_ADDR gmEpSendCount;
|
||||
GM_ADDR gmResvered;
|
||||
GM_ADDR gmOutputRecvCount;
|
||||
GM_ADDR gmExpertTokenNums;
|
||||
|
||||
uint32_t epRankSize;
|
||||
uint32_t epRankId;
|
||||
@@ -440,7 +441,7 @@ public:
|
||||
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
|
||||
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
|
||||
GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
|
||||
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
|
||||
GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_,
|
||||
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
|
||||
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
|
||||
uint32_t h)
|
||||
@@ -465,7 +466,7 @@ public:
|
||||
gmexpertIds(gmexpertIds_),
|
||||
gmExpandIdx(gmExpandIdx_),
|
||||
gmEpSendCount(gmEpSendCount_),
|
||||
gmOutputRecvCount(gmOutputRecvCount_),
|
||||
gmExpertTokenNums(gmExpertTokenNums_),
|
||||
gmXActiveMask(gmXActiveMask_),
|
||||
gmResvered(gmResvered_),
|
||||
epRankSize(epRankSize_),
|
||||
@@ -535,7 +536,10 @@ public:
|
||||
AscendC::GlobalTensor<ElementA> gmA;
|
||||
gmA.SetGlobalBuffer(params.ptrA);
|
||||
AscendC::GlobalTensor<ElementB> gmB;
|
||||
gmB.SetGlobalBuffer(params.ptrB);
|
||||
AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB));
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr<int32_t>(0)));
|
||||
}
|
||||
|
||||
AscendC::GlobalTensor<ElementGroupList> groupList;
|
||||
groupList.SetGlobalBuffer(params.ptrGroupList);
|
||||
@@ -555,6 +559,10 @@ public:
|
||||
static_cast<uint8_t>(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part
|
||||
uint32_t target = 1;
|
||||
for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) {
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
|
||||
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(
|
||||
gmBlistTensorDesc.GetDataPtr<int32_t>(groupIdx)));
|
||||
}
|
||||
groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) +
|
||||
groupIdx * GROUP_INFO_SIZE);
|
||||
// wait AIV recv needed tokens
|
||||
@@ -619,7 +627,9 @@ public:
|
||||
}
|
||||
|
||||
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
|
||||
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
|
||||
}
|
||||
|
||||
startCoreIdx = (startCoreIdx + coreLoops) % aicNum;
|
||||
}
|
||||
@@ -1087,7 +1097,7 @@ public:
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount)
|
||||
void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset)
|
||||
{
|
||||
// calculate token index in output tensor
|
||||
int64_t subUbOffset = ubOffset;
|
||||
@@ -1113,15 +1123,6 @@ public:
|
||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||
AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM,
|
||||
{1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
|
||||
if (isRecvCore && recvCoreIdx == 0) {
|
||||
AscendC::GlobalTensor<int32_t> recvCountTensor;
|
||||
recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount);
|
||||
AscendC::DataCopyExtParams dataCopyParams = {
|
||||
1U, static_cast<uint32_t>(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U};
|
||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(0);
|
||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(0);
|
||||
AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast<int32_t>(), dataCopyParams);
|
||||
}
|
||||
AscendC::LocalTensor<float> workLocalTensor = resource.ubBuf.template GetBufferByByte<float>(subUbOffset);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
AscendC::ReduceSum<float>(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor,
|
||||
@@ -1222,7 +1223,7 @@ public:
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount)
|
||||
void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount)
|
||||
{
|
||||
ubOffset = 0;
|
||||
RecvCount(ubOffset);
|
||||
@@ -1249,7 +1250,7 @@ public:
|
||||
|
||||
if (startRankId < recvExpertNum) {
|
||||
// RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info
|
||||
GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount);
|
||||
GetCumSum(startRankId, recvExpertNum, ubOffset);
|
||||
RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset);
|
||||
}
|
||||
|
||||
@@ -1291,8 +1292,13 @@ public:
|
||||
uint32_t stageId = 0;
|
||||
uint32_t target = 1;
|
||||
uint32_t startCoreIdx = 0;
|
||||
|
||||
AscendC::ListTensorDesc gmScaleListTensor;
|
||||
AscendC::GlobalTensor<int32_t> groupTokenNumStateTensor;
|
||||
gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale));
|
||||
__gm__ ElementScale* gmScalePtr;
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr<int32_t>(0));
|
||||
}
|
||||
for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) {
|
||||
// just like AIC
|
||||
groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) +
|
||||
@@ -1311,12 +1317,22 @@ public:
|
||||
LayoutPerTokenScale layoutPerTokenScale =
|
||||
wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
|
||||
LayoutD layoutD = layout::RowMajor{currentM, n};
|
||||
EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
|
||||
layoutScale,
|
||||
gmTokenScale + gmGroupOffsetPerTokenScale,
|
||||
layoutPerTokenScale,
|
||||
gmSwigluOutput + gmGroupOffsetD,
|
||||
layoutD};
|
||||
EpilogueParams epilogueParams;
|
||||
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
|
||||
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(
|
||||
gmScaleListTensor.GetDataPtr<int32_t>(groupIdx));
|
||||
epilogueParams = EpilogueParams {
|
||||
gmScalePtr, layoutScale,
|
||||
gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
|
||||
gmSwigluOutput + gmGroupOffsetD, layoutD};
|
||||
} else {
|
||||
epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale,
|
||||
layoutScale,
|
||||
gmTokenScale + gmGroupOffsetPerTokenScale,
|
||||
layoutPerTokenScale,
|
||||
gmSwigluOutput + gmGroupOffsetD,
|
||||
layoutD};
|
||||
}
|
||||
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
|
||||
blockEpilogue.UpdateParams(epilogueParams);
|
||||
uint32_t coreLoops = blockScheduler.GetCoreLoops();
|
||||
@@ -1340,7 +1356,9 @@ public:
|
||||
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
|
||||
}
|
||||
|
||||
gmGroupOffsetScale += inGroupProblemShape.n();
|
||||
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
|
||||
gmGroupOffsetScale += inGroupProblemShape.n();
|
||||
}
|
||||
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
|
||||
gmGroupOffsetD += currentM * n;
|
||||
|
||||
@@ -1485,7 +1503,7 @@ public:
|
||||
}
|
||||
|
||||
CATLASS_DEVICE
|
||||
void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount)
|
||||
void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums)
|
||||
{
|
||||
if (aivIdx == aiCoreGroupNum * subBlockNum - 1) {
|
||||
// clean
|
||||
@@ -1504,19 +1522,32 @@ public:
|
||||
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList));
|
||||
AscendC::GlobalTensor<int32_t> sendCountsGlobal;
|
||||
sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount));
|
||||
AscendC::GlobalTensor<int64_t> nonCumSumExpertTokenNumsTensor;
|
||||
nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums);
|
||||
uint32_t tmpTokenNum = 0;
|
||||
for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) {
|
||||
__asm__ __volatile__("");
|
||||
AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
|
||||
AscendC::DcciDst::CACHELINE_OUT>(
|
||||
sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]);
|
||||
__asm__ __volatile__("");
|
||||
|
||||
uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1);
|
||||
expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum);
|
||||
uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum;
|
||||
nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum);
|
||||
tmpTokenNum = tokenNum;
|
||||
|
||||
__asm__ __volatile__("");
|
||||
AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
|
||||
AscendC::DcciDst::CACHELINE_OUT>(
|
||||
expertTokenNumsOutGMTensor_[localMoeIndex]);
|
||||
__asm__ __volatile__("");
|
||||
__asm__ __volatile__("");
|
||||
AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
|
||||
AscendC::DcciDst::CACHELINE_OUT>(
|
||||
nonCumSumExpertTokenNumsTensor[localMoeIndex]);
|
||||
__asm__ __volatile__("");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1531,8 +1562,7 @@ public:
|
||||
(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
|
||||
}
|
||||
if (isRecvCore) {
|
||||
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
|
||||
(GM_ADDR)params.gmOutputRecvCount);
|
||||
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount);
|
||||
}
|
||||
|
||||
auto gmSwigluOutput = reinterpret_cast<__gm__ float *>(
|
||||
@@ -1547,7 +1577,7 @@ public:
|
||||
AscendC::SyncAll<false>();
|
||||
AscendC::PipeBarrier<PIPE_ALL>();
|
||||
|
||||
UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount);
|
||||
UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums);
|
||||
{
|
||||
// dynamic quant
|
||||
AscendC::GlobalTensor<int32_t> sendCountsGlobal;
|
||||
|
||||
@@ -30,6 +30,7 @@ struct DispatchGmmCombineDecodeInfo {
|
||||
uint64_t totalUbSize;
|
||||
uint64_t totalWinSize;
|
||||
uint64_t gmm1HLen;
|
||||
bool isTensorList;
|
||||
};
|
||||
|
||||
struct DispatchGmmCombineDecodeTilingData {
|
||||
@@ -70,6 +71,7 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
|
||||
constexpr uint32_t WORKSPACE_STAGES = 4;
|
||||
|
||||
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
|
||||
constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1);
|
||||
constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
|
||||
|
||||
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H
|
||||
|
||||
Reference in New Issue
Block a user