[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)

### What this PR does / why we need it?

This PR adapt DispatchGmmCombineDecode operator to parameters of small
operators.
1. This operator no longer requires permuting the weights and scales of
GMM1.
2. This operator no longer requires transposing the weights of GMM2.

Therefore, this operator and the small operator can use the same
parameters (weights and scales), which is beneficial for model
adaptation.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-09 16:17:06 +08:00
committed by GitHub
parent 9a885d08d0
commit 9567e5dd8c
5 changed files with 118 additions and 142 deletions

View File

@@ -238,6 +238,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
uint64_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
if (epRankId < sharedExpertRankNum) {
maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
@@ -245,20 +246,23 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
}
size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t CVSwapBufferSize =
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize +
epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize;
size_t usrSize = maxTokenSize + tokenScaleSize + CVSwapBufferSize + maxSwigluGmm2Size + groupListSize + expandIdxSize +
epSendCountSize + resveredSize;
workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
return ge::GRAPH_SUCCESS;