[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)
### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the parameters of the small
operators.
1. This operator no longer requires permuting the weights and scales of
GMM1.
2. This operator no longer requires transposing the weights of GMM2.
Therefore, this operator and the small operators can use the same
parameters (weights and scales), which simplifies model
adaptation.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -238,6 +238,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
|
||||
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||
uint32_t aicNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.aicNum;
|
||||
uint64_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
||||
uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2;
|
||||
if (epRankId < sharedExpertRankNum) {
|
||||
maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum;
|
||||
@@ -245,20 +246,23 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
|
||||
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
|
||||
}
|
||||
|
||||
size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE);
|
||||
size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
|
||||
size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
|
||||
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
|
||||
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
|
||||
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t CVSwapBufferSize =
|
||||
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
|
||||
size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
|
||||
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
|
||||
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
|
||||
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
|
||||
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
|
||||
size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE);
|
||||
size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE);
|
||||
size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE);
|
||||
size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE);
|
||||
size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE);
|
||||
size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize +
|
||||
epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize;
|
||||
size_t usrSize = maxTokenSize + tokenScaleSize + CVSwapBufferSize + maxSwigluGmm2Size + groupListSize + expandIdxSize +
|
||||
epSendCountSize + resveredSize;
|
||||
|
||||
workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize;
|
||||
return ge::GRAPH_SUCCESS;
|
||||
|
||||
Reference in New Issue
Block a user