[Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552)

#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the EPLB tensor list and
expert token numbers.

This operator supports gmm1, gmm2, gmm1Scale and gmm2Scale in list
format.
This operator supports counting how many tokens each local expert receives
via the expertTokenNums output.


- vLLM version: v0.13.0
- vLLM main:
7157596103

For more info about this operator, please refer to the RFC issue:
https://github.com/vllm-project/vllm-ascend/issues/5476
This commit is contained in:
wangyibo1005
2026-01-07 11:23:42 +08:00
committed by GitHub
parent 086c093347
commit 25baf6df09
18 changed files with 425 additions and 195 deletions

View File

@@ -26,10 +26,10 @@ extern "C" {
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -42,7 +42,7 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor);
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
@@ -54,10 +54,10 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -70,14 +70,14 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
output, expertTokenNums, workspaceSize, executor);
}
aclnnStatus aclnnDispatchGmmCombineDecode(

View File

@@ -19,10 +19,10 @@ extern "C" {
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -35,7 +35,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor);

View File

@@ -26,22 +26,22 @@ public:
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
@@ -65,9 +65,9 @@ public:
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
this->Output("expert_token_nums")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.DataType({ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();

View File

@@ -18,7 +18,7 @@ namespace ge {
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
constexpr uint32_t OUTPUT_EXPERT_TOKEN_NUMS = 1;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
@@ -36,9 +36,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
gert::Shape *expertTokenNumsShape = context->GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS);
if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
recvCountOutShape == nullptr) {
expertTokenNumsShape == nullptr) {
return GRAPH_FAILED;
}
if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
@@ -72,12 +72,12 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
recvCountOutShape->SetDimNum(1);
expertTokenNumsShape->SetDimNum(1);
bool isShareExpert = (epRankId < sharedExpertRankNum);
if (isShareExpert) {
recvCountOutShape->SetDim(0, epRankSize);
expertTokenNumsShape->SetDim(0, 1);
} else {
recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
expertTokenNumsShape->SetDim(0, moeExpertNum / (epRankSize - sharedExpertRankNum));
}
return GRAPH_SUCCESS;
@@ -87,7 +87,7 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX);
context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType);
context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
context->SetOutputDataType(OUTPUT_EXPERT_TOKEN_NUMS, ge::DT_INT64);
return ge::GRAPH_SUCCESS;
}

View File

@@ -58,6 +58,9 @@ constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
constexpr uint32_t TENSOR_HIDDEN_INDEX = 1;
constexpr uint32_t SINGLE_HIDDEN_INDEX = 2;
constexpr uint32_t MAX_TENSOR_COUNT = 256;
} // namespace
namespace optiling {
@@ -66,8 +69,176 @@ static size_t CeilUp(size_t x, size_t y)
return (x + y - 1) / y * y;
}
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
// Count how many tensors are present in the dynamic (list) input at descIndex.
// Probes consecutive element indices until the first missing tensor; the scan
// is capped at MAX_TENSOR_COUNT to bound the loop.
static uint32_t CountTensorListLen(gert::TilingContext *context, int descIndex)
{
    uint32_t count = 0;  // uint32_t (not int): avoid implicit signed->unsigned conversion on return
    for (uint32_t i = 0; i < MAX_TENSOR_COUNT; i++) {
        auto tensorElement = context->GetDynamicInputTensor(descIndex, i);
        if (tensorElement == nullptr) {
            break;
        }
        count++;
    }
    return count;
}
// Validate the gmm1 (up-projection) weight input shape, record the gmm1 hidden
// length (gmm1HLen) and whether the input arrived as a tensor list (EPLB mode).
// Returns GRAPH_FAILED with a log message on any mismatch.
static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint32_t gmm1ListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_INDEX);
    auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0);
    // Guard against an empty tensor list before dereferencing the first element.
    OPS_ERR_IF(gmm1FirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm1Weight first tensor is null."),
        return ge::GRAPH_FAILED);
    auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."),
        return ge::GRAPH_FAILED);
    if (gmm1ListLen > 1) { // List: one (h, n) tensor per local expert
        OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
            OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
            return ge::GRAPH_FAILED);
        OPS_ERR_IF(gmm1ListLen != moeExpertNumPerRank,
            OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number perRank."),
            return ge::GRAPH_FAILED);
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
            gmm1FirstTensorElementShape.GetDim(TENSOR_HIDDEN_INDEX);
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = true;
    } else { // Single tensor
        if (elementDims == 2) { // one localExpert perRank: (h, n)
            OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
                return ge::GRAPH_FAILED);
            tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
                gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX - 1);
        } else { // multi localExperts perRank: (e, h, n)
            OPS_ERR_IF(moeExpertNumPerRank != gmm1FirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number per rank."),
                return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(1),
                OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
                return ge::GRAPH_FAILED);
            tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
                gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX);
        }
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = false;
    }
    return ge::GRAPH_SUCCESS;
}
// Validate the gmm1 weight dequant-scale input shape against the gmm1 hidden
// size (gmm1HLen, set earlier by CheckGmm1Shape) and the local expert count.
// Returns GRAPH_FAILED with a log message on any mismatch.
static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context,
    DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
    uint32_t gmm1ScaleListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_SCALE_INDEX);
    auto gmm1ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_SCALE_INDEX, 0);
    // Guard against an empty tensor list before dereferencing the first element.
    OPS_ERR_IF(gmm1ScaleFirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm1WeightScale first tensor is null."),
        return ge::GRAPH_FAILED);
    auto gmm1ScaleFirstTensorElementShape = gmm1ScaleFirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm1ScaleFirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm1WeightScale shape is invalid."),
        return ge::GRAPH_FAILED);
    if (gmm1ScaleListLen > 1) { // List: one (n,) scale per local expert
        OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
            OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
    } else { // Single tensor
        if (elementDims == 1) { // one localExpert perRank: (n,)
            OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
        } else { // multi localExperts perRank: (e, n)
            OPS_ERR_IF(moeExpertNumPerRank != gmm1ScaleFirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm1Scale does not match local expert number perRank."), return ge::GRAPH_FAILED);
            OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(1),
                OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
// Validate the gmm2 (down-projection) weight input shape: each expert's weight
// is (n/2, h) where n is the gmm1 hidden size (halved by swiglu) and h is the
// token hidden size. Returns GRAPH_FAILED with a log message on any mismatch.
// NOTE: the original error messages for the n/2 and h checks were swapped;
// they now name the dimension that actually failed.
static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
    uint32_t gmm2ListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_INDEX);
    auto gmm2FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_INDEX, 0);
    // Guard against an empty tensor list before dereferencing the first element.
    OPS_ERR_IF(gmm2FirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm2Weight first tensor is null."),
        return ge::GRAPH_FAILED);
    auto gmm2FirstTensorElementShape = gmm2FirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."),
        return ge::GRAPH_FAILED);
    if (gmm2ListLen > 1) { // List: one (n/2, h) tensor per local expert
        OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank,
            OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED);
        OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
            OPS_LOG_E(nodeName, "gmm2 does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
        OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
            OPS_LOG_E(nodeName, "gmm2 does not equals to token hidden size."), return ge::GRAPH_FAILED);
    } else { // Single tensor
        if (elementDims == 2) { // one localExpert perRank: (n/2, h)
            OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
                OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED);
        } else { // multi localExperts perRank: (e, n/2, h)
            OPS_ERR_IF(moeExpertNumPerRank != gmm2FirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm2Weight does not match local expert num perRank."), return ge::GRAPH_FAILED);
            OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(1),
                OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(2),
                OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
// Validate the gmm2 weight dequant-scale input shape against the token hidden
// size (h) and the local expert count. Returns GRAPH_FAILED with a log message
// on any mismatch.
static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context,
    DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint32_t gmm2ScaleListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_SCALE_INDEX);
    auto gmm2ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_SCALE_INDEX, 0);
    // Guard against an empty tensor list before dereferencing the first element.
    OPS_ERR_IF(gmm2ScaleFirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm2WeightScale first tensor is null."),
        return ge::GRAPH_FAILED);
    auto gmm2ScaleFirstTensorElementShape = gmm2ScaleFirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm2ScaleFirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm2WeightScale shape is invalid."),
        return ge::GRAPH_FAILED);
    if (gmm2ScaleListLen > 1) { // List: one (h,) scale per local expert
        OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
            OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
    } else { // Single tensor
        if (elementDims == 1) { // one localExpert perRank: (h,)
            OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
        } else { // multi localExperts perRank: (e, h)
            OPS_ERR_IF(moeExpertNumPerRank != gmm2ScaleFirstTensorElementShape.GetDim(0),
                OPS_LOG_E(nodeName, "gmm2Scale does not match local expert number perRank."), return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(1),
                OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
// Run every per-input weight/scale shape check in order; fail fast on the
// first mismatch (each sub-check logs its own error before returning).
static ge::graphStatus CheckWeightTensorList(gert::TilingContext *context,
    DispatchGmmCombineDecodeTilingData *tilingData)
{
    if (CheckGmm1Shape(context, tilingData) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckGmm1ScaleShape(context, tilingData) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckGmm2Shape(context, tilingData) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckGmm2ScaleShape(context, tilingData) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}
ge::graphStatus CheckXActiveMaskShape(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
{
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
@@ -76,55 +247,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
if (xActiveMaskStorageShape != nullptr) {
@@ -312,21 +435,19 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
OPS_ERR_IF(CheckWeightTensorList(context, tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckWeightTensorList failed."), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckXActiveMaskShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckXActiveMaskShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
@@ -338,6 +459,9 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
tilingKey |= EXEC_FLAG_DEEP_FUSE;
}
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) {
tilingKey |= EXEC_FLAG_TENSOR_LIST;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;