[Refactor] Add expert processed token count output for DispatchFFNCombine/DispatchFFNCombineBF16 (#6402)
### What this PR does / why we need it?
Add a new output for the expert token count.

An additional output tensor `expert_token_nums` is added to both operators so that the token distribution among experts can be tracked:
- Tensor name: `expert_token_nums`
- Dimension: 1-D tensor
- Shape: `(local_expert_num,)`
- Data type: int32
- Semantics: the number of tokens actually received by each expert on the current card.

A minimal sketch of these semantics follows the list.
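To make the new output concrete, here is a minimal host-side sketch (not the operator's actual code): it assumes a hypothetical `tokenPerExpert[src][e]` table of how many tokens each source EP rank sends to each local expert, and shows that `expert_token_nums[e]` is simply the sum over source ranks — the quantity the kernel reads back from the last row of its cumulative-sum workspace. The names `ep`, `localExpertNum`, and `tokenPerExpert` are illustrative, not part of the operator API.

```cpp
// Illustrative only: ep, localExpertNum and tokenPerExpert are made-up names,
// not part of the DispatchFFNCombine API.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const int ep = 4;              // EP world size (assumed for the example)
    const int localExpertNum = 2;  // experts hosted on the current card

    // tokenPerExpert[src][e]: tokens sent by source rank `src` to local expert `e`.
    std::vector<std::vector<int32_t>> tokenPerExpert = {{3, 1}, {0, 5}, {2, 2}, {4, 0}};

    // expert_token_nums[e] = total tokens received by local expert `e`,
    // i.e. the last row of a cumulative sum over the rank dimension.
    std::vector<int32_t> expertTokenNums(localExpertNum, 0);
    for (int src = 0; src < ep; ++src) {
        for (int e = 0; e < localExpertNum; ++e) {
            expertTokenNums[e] += tokenPerExpert[src][e];
        }
    }

    for (int e = 0; e < localExpertNum; ++e) {
        std::cout << "expert " << e << ": " << expertTokenNums[e] << " tokens\n";
    }
    // Expected output: expert 0: 9 tokens, expert 1: 8 tokens
    return 0;
}
```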
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
Signed-off-by: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
@@ -20,32 +20,32 @@
 using namespace AscendC;
 using namespace DispatchFFNCombineImpl;
 extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
-    GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+    GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     if (TILING_KEY_IS(1000000)) {
         KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000001)) {
         KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000010)) {
         KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000011)) {
         KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     }
 }
@@ -55,7 +55,7 @@ class DispatchFFNCombine {
 public:
     __aicore__ inline DispatchFFNCombine() {};
     __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-        GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
+        GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
     __aicore__ inline void Process();


@@ -68,6 +68,7 @@ private:
     GM_ADDR scale2GM_;
     GM_ADDR probs_;
     GM_ADDR outGM_;
+    GM_ADDR gmExpertTokenNums_;
     GM_ADDR workspaceGM_;

     GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
@@ -112,7 +113,7 @@ private:

 template <TemplateMMA2AClass>
 __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-    GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+    GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
@@ -127,6 +128,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM,
     probs_ = probs;

     outGM_ = outGM;
+    gmExpertTokenNums_ = expertTokenNums;

     workspaceGM_ = workspaceGM;

@@ -268,7 +270,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
         outGM_, layoutD1, layoutD2,
         expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
         expertTokensBeforeCapacity, probs_,
-        workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
+        workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
     //Call kernel
     MatmulKernel kernel(params);
     kernel(params);
@@ -96,6 +96,7 @@ public:
     LayoutD1 layoutD1;
     LayoutD2 layoutD2;
     GM_ADDR ptrWorkspace;
+    GM_ADDR ptrExpertTokenNums;
     int32_t EP;
     int32_t listLen;
     int32_t expertPerRank;
@@ -139,7 +140,7 @@ public:
         GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
         GM_ADDR moeInitRoutingQuantV2Offset_,
         GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
-        GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
+        GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
         optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
     ) : problemShape(problemShape_),
         EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
@@ -155,7 +156,7 @@ public:
         expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
         moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
         expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
-        ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),
+        ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),
         moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
     {
     }
@@ -228,7 +229,7 @@ private:

         tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));

-        tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
+        tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
     }

     template<typename T>
@@ -335,15 +336,6 @@ private:
         AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
         syncgmmIdx++;

-        constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
-        __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
-        __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
-
-        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-            weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
-            scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
-        }
         AscendC::PipeBarrier<PIPE_ALL>();

         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -356,8 +348,8 @@ private:
             AscendC::GlobalTensor<ElementB> gmB1;
             AscendC::GlobalTensor<ElementScale> gmS;
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-            gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
-            gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
+            gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
+            gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));

             AscendC::PipeBarrier<PIPE_ALL>();

@@ -455,14 +447,6 @@ private:
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }

-        constexpr uint32_t MAX_EXPERTS_PER_RANK = 8;
-        __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
-        __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
-        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-            weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
-            scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
-        }
         AscendC::PipeBarrier<PIPE_ALL>();

         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -476,8 +460,8 @@ private:
             AscendC::GlobalTensor<ElementScale> gmS2;
             AscendC::PipeBarrier<PIPE_ALL>();
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-            gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
-            gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
+            gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
+            gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));

             if (currentM <= L1TileShape::M) {
                 gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
@@ -596,7 +580,6 @@ private:
             AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
             AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
         }
-
         for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
             if (dstEpIdx == params.rank) {
                 continue;
@@ -639,6 +622,13 @@ private:
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
         AscendC::SyncAll<true>();
+
+        AscendC::GlobalTensor<int32_t> ExpertTokenNums;
+        ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
+        AscendC::GlobalTensor<int32_t> LcalCumsumMM;
+        LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
+        CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
+        AscendC::SyncAll<true>();
         uint16_t syncgmm1Idx = 0;
         AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
         syncgmm1Idx++;