[Refactor] Add expert processed token count output for DispatchFFNCombine/DispatchFFNCombineBF16 (#6402)

### What this PR does / why we need it?
Add New Output for Expert Token Count
An additional output tensor, expert_token_nums, is added to both operators
so that the distribution of tokens among experts can be tracked:

Tensor Name: expert_token_nums
Dimension: 1D tensor
Shape: (local_expert_num,)
Data Type: int32
Semantics: The number of tokens actually received by each expert on the
current card (device).
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
Signed-off-by: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
This commit is contained in:
guanguan0308
2026-02-03 10:41:06 +08:00
committed by GitHub
parent 26b83f8bde
commit dffac6db73
18 changed files with 97 additions and 84 deletions

View File

@@ -20,32 +20,32 @@
using namespace AscendC;
using namespace DispatchFFNCombineBF16Impl;
// Kernel entry point for the fused BF16 Dispatch + FFN + Combine operator.
//
// Selects a DispatchFFNCombineBF16 template instantiation from the tiling key;
// the four branches differ only in the two trailing boolean template flags.
//
// Inputs : x, w1, w2, expertId, scale1, scale2, probs   (global-memory tensors)
// Outputs: c               - combined FFN result
//          expertTokenNums - 1-D int32 tensor of shape (local_expert_num,),
//                            the number of tokens received by each local expert
// Scratch: workspaceGM; tilingGM carries DispatchFFNCombineBF16TilingData.
extern "C" __global__ __aicore__ void dispatch_ffn_combine_bf16(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
    GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
    REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
    if (TILING_KEY_IS(1000000)) {
        KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
        // NOTE(review): keys 1000000 and 1000010 instantiate the same
        // <false, true> flag pair — confirm 1000000 was not meant to be
        // <false, false>.
        DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000001)) {
        KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
        DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, false> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000010)) {
        KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
        DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
        op.Process();
    } else if (TILING_KEY_IS(1000011)) {
        KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
        GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
        DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, true> op;
        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
        op.Process();
    }
}

View File

@@ -57,7 +57,7 @@ class DispatchFFNCombineBF16 {
public:
__aicore__ inline DispatchFFNCombineBF16() {};
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
__aicore__ inline void Process();
@@ -70,6 +70,7 @@ private:
GM_ADDR scale2GM_;
GM_ADDR probs_;
GM_ADDR outGM_;
GM_ADDR gmExpertTokenNums_;
GM_ADDR workspaceGM_;
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
@@ -114,7 +115,7 @@ private:
template <TemplateMMA2AClass>
__aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
auto tiling = (__gm__ DispatchFFNCombineBF16TilingData*)tilingGM;
@@ -129,6 +130,7 @@ __aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Init(GM_ADDR
probs_ = probs;
outGM_ = outGM;
gmExpertTokenNums_ = expertTokenNums;
workspaceGM_ = workspaceGM;
@@ -278,7 +280,7 @@ __aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Process()
outGM_, layoutD1, layoutD2,
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
expertTokensBeforeCapacity, probs_,
workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
//Call kernel
MatmulKernel kernel(params);
kernel(params);

View File

@@ -97,6 +97,7 @@ public:
LayoutD1 layoutD1;
LayoutD2 layoutD2;
GM_ADDR ptrWorkspace;
GM_ADDR ptrExpertTokenNums;
int32_t EP;
int32_t listLen;
int32_t expertPerRank;
@@ -141,7 +142,7 @@ public:
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
GM_ADDR moeInitRoutingQuantV2Offset_,
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData_,
GM_ADDR symmetricPtr_ = nullptr
) : problemShape(problemShape_),
@@ -158,7 +159,7 @@ public:
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
{
moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingData_.vbsComputeParamsOp;
@@ -518,14 +519,6 @@ CATLASS_DEVICE
int64_t preCurrentmSum = 0;
int32_t syncLoopIdx = -1;
__gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
__gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
}
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -538,8 +531,8 @@ CATLASS_DEVICE
AscendC::GlobalTensor<ElementB> gmB1;
AscendC::GlobalTensor<ElementScale> gmS;
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
AscendC::PipeBarrier<PIPE_ALL>();
@@ -647,13 +640,6 @@ CATLASS_DEVICE
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
__gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
__gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
}
AscendC::PipeBarrier<PIPE_ALL>();
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -667,8 +653,8 @@ CATLASS_DEVICE
AscendC::GlobalTensor<ElementScale> gmS2;
AscendC::PipeBarrier<PIPE_ALL>();
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
if (currentM <= L1TileShape::M) {
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
@@ -766,6 +752,13 @@ CATLASS_DEVICE
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
syncgmm1Idx++;
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
AscendC::SyncAll<true>();
uint32_t curGroupOffset = 0;
int32_t prevSumBeforeRank = 0;
int32_t groupIdxDeq = 0;
@@ -776,7 +769,8 @@ CATLASS_DEVICE
uint32_t prevGroupSum1 = 0;
uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
BlockEpilogue1 blockEpilogue(resource);
uint32_t n = params.problemShape.n();
BlockEpilogue1 blockEpilogue(resource, n);
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;

View File

@@ -85,16 +85,16 @@ public:
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
BlockEpilogue(Arch::Resource<ArchTag> const &resource, int32_t n, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
constexpr uint32_t blockN = 4096;
constexpr uint32_t ChunkTileLen = blockN / 2;
constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
uint32_t blockN = n;
uint32_t ChunkTileLen = blockN / 2;
uint32_t HalfChunkTileLen = ChunkTileLen / 2;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);