[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode add an optional input
x_active_mask, with which
only token masked True will be dispatched and handle.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-19 16:31:48 +08:00
committed by GitHub
parent 636265be6d
commit 118b0ed346
14 changed files with 292 additions and 96 deletions

View File

@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
}

View File

@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,

View File

@@ -45,14 +45,19 @@ public:
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
this->Input("x_active_mask")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.DataType({ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")

View File

@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
constexpr uint32_t SUPPORT_TOP_K = 12;
constexpr uint32_t ONE_DIMS = 1;
constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
if (xActiveMaskStorageShape != nullptr) {
OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
OPS_LOG_E(nodeName, " xActiveMask scale shape dims must be 1, but current dim num is %lu.",
xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS;
}
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
context->SetTilingKey(0);
} else {
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
uint64_t tilingKey = 0;
if (xActiveMaskEnable) {
tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
}
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
tilingKey |= EXEC_FLAG_DEEP_FUSE;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;
}