[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)
### What this PR does / why we need it?
This PR adds an optional input, x_active_mask, to the DispatchGmmCombineDecode operator.
When x_active_mask is provided,
only tokens whose mask value is True are dispatched and handled.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
|
||||
const aclTensor *gmm1PermutedWeightScale,
|
||||
const aclTensor *gmm2Weight,
|
||||
const aclTensor *gmm2WeightScale,
|
||||
const aclTensor *expertScales,
|
||||
const aclTensor *expertSmoothScalesOptional,
|
||||
const aclTensor *expertScalesOptional,
|
||||
const aclTensor *xActiveMaskOptional,
|
||||
char *groupEp,
|
||||
int64_t epRankSize,
|
||||
int64_t epRankId,
|
||||
@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
|
||||
const aclTensor *gmm1PermutedWeightScale,
|
||||
const aclTensor *gmm2Weight,
|
||||
const aclTensor *gmm2WeightScale,
|
||||
const aclTensor *expertScales,
|
||||
const aclTensor *expertSmoothScalesOptional,
|
||||
const aclTensor *expertScalesOptional,
|
||||
const aclTensor *xActiveMaskOptional,
|
||||
char *groupEp,
|
||||
int64_t epRankSize,
|
||||
int64_t epRankId,
|
||||
@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
|
||||
aclOpExecutor **executor)
|
||||
{
|
||||
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
|
||||
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
|
||||
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
|
||||
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
|
||||
output, epRecvCount, workspaceSize, executor);
|
||||
}
|
||||
|
||||
@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
|
||||
const aclTensor *gmm1PermutedWeightScale,
|
||||
const aclTensor *gmm2Weight,
|
||||
const aclTensor *gmm2WeightScale,
|
||||
const aclTensor *expertScales,
|
||||
const aclTensor *expertSmoothScalesOptional,
|
||||
const aclTensor *expertScalesOptional,
|
||||
const aclTensor *xActiveMaskOptional,
|
||||
char *groupEp,
|
||||
int64_t epRankSize,
|
||||
int64_t epRankId,
|
||||
|
||||
@@ -45,14 +45,19 @@ public:
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_scales")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_smooth_scales")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_scales")
|
||||
this->Input("x_active_mask")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.DataType({ge::DT_BOOL, ge::DT_BOOL})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("output")
|
||||
|
||||
@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
|
||||
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
|
||||
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
|
||||
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
|
||||
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
|
||||
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
|
||||
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
|
||||
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
|
||||
constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;
|
||||
|
||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
|
||||
@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
|
||||
constexpr uint32_t MAX_BATCH_SIZE = 256;
|
||||
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
|
||||
constexpr uint32_t SUPPORT_TOP_K = 12;
|
||||
constexpr uint32_t ONE_DIMS = 1;
|
||||
constexpr uint32_t TWO_DIMS = 2;
|
||||
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
|
||||
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
|
||||
@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
|
||||
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
|
||||
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
|
||||
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
||||
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
||||
|
||||
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
|
||||
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
|
||||
return ge::GRAPH_FAILED);
|
||||
|
||||
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
||||
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
||||
if (xActiveMaskStorageShape != nullptr) {
|
||||
OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
|
||||
OPS_LOG_E(nodeName, " xActiveMask scale shape dims must be 1, but current dim num is %lu.",
|
||||
xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
|
||||
return ge::GRAPH_FAILED);
|
||||
const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
|
||||
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
|
||||
"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
|
||||
return ge::GRAPH_FAILED);
|
||||
}
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
|
||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
|
||||
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
|
||||
return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
|
||||
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
|
||||
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
|
||||
SetHcommCfg(context, tilingData, groupEp);
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
|
||||
context->SetTilingKey(0);
|
||||
} else {
|
||||
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
|
||||
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
||||
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
||||
bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
|
||||
uint64_t tilingKey = 0;
|
||||
if (xActiveMaskEnable) {
|
||||
tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
|
||||
}
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
|
||||
tilingKey |= EXEC_FLAG_DEEP_FUSE;
|
||||
}
|
||||
context->SetTilingKey(tilingKey);
|
||||
context->SetBlockDim(aicNum);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user