|
|
|
|
@@ -58,6 +58,9 @@ constexpr uint32_t MIN_TOKEN_LENGTH = 512;
|
|
|
|
|
// Upper bound paired with MIN_TOKEN_LENGTH; presumably the allowed token hidden
// size range -- the check that enforces it is not visible here, TODO confirm.
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
|
|
|
|
|
// Lower bound for the GMM1 hidden size; presumably validated against gmm1HLen
// elsewhere in the file -- TODO confirm.
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
|
|
|
|
|
// Upper bound for the GMM1 hidden size (see MIN_GMM1_HIDDEN).
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
|
|
|
|
|
// Dimension index of the GMM1 hidden size inside one element of a weight
// tensor list; CheckGmm1Shape reads GetDim(TENSOR_HIDDEN_INDEX) in list mode.
constexpr uint32_t TENSOR_HIDDEN_INDEX = 1;
|
|
|
|
|
// Dimension index of the GMM1 hidden size inside a single 3-D weight tensor;
// CheckGmm1Shape uses SINGLE_HIDDEN_INDEX - 1 for the 2-D (one expert) layout.
constexpr uint32_t SINGLE_HIDDEN_INDEX = 2;
|
|
|
|
|
// Maximum number of list entries CountTensorListLen probes before giving up.
constexpr uint32_t MAX_TENSOR_COUNT = 256;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace optiling {
|
|
|
|
|
@@ -66,8 +69,176 @@ static size_t CeilUp(size_t x, size_t y)
|
|
|
|
|
return (x + y - 1) / y * y;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
|
|
|
|
|
DispatchGmmCombineDecodeTilingData &tilingData)
|
|
|
|
|
// Returns how many consecutive tensors are present in the dynamic input
// `descIndex`, probing indices 0, 1, 2, ... until GetDynamicInputTensor yields
// nullptr or MAX_TENSOR_COUNT entries have been seen.
static uint32_t CountTensorListLen(gert::TilingContext *context, int descIndex)
{
    uint32_t listLen = 0;
    // listLen doubles as the next index to probe: it only advances while the
    // tensor at that position exists.
    while (listLen < MAX_TENSOR_COUNT && context->GetDynamicInputTensor(descIndex, listLen) != nullptr) {
        ++listLen;
    }
    return listLen;
}
|
|
|
|
|
|
|
|
|
|
// Validates the gmm1 weight input and records its layout in the tiling data.
//
// Only the first tensor of the dynamic input is inspected. Three layouts are
// accepted:
//   * tensor list (more than one list entry): list length must equal
//     moeExpertNumPerRank, dim 0 of the element must equal h;
//   * single 2-D tensor (one local expert): dim 0 must equal h;
//   * single 3-D tensor (several local experts): dim 0 must equal
//     moeExpertNumPerRank and dim 1 must equal h.
//
// On success writes `gmm1HLen` and `isTensorList` into
// tilingData->disGmmDeqSwigluQuantGmmDeqComInfo and returns GRAPH_SUCCESS;
// otherwise logs the reason and returns GRAPH_FAILED.
static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint32_t gmm1ListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_INDEX);
    auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0);
    // GetDynamicInputTensor returns nullptr when the entry is absent (an empty
    // list); bail out instead of dereferencing it below.
    OPS_ERR_IF(gmm1FirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm1Weight tensor is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum();

    OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."),
               return ge::GRAPH_FAILED);
    if (gmm1ListLen > 1) {  // List
        OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
                   OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(gmm1ListLen != moeExpertNumPerRank,
                   OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number perRank."),
                   return ge::GRAPH_FAILED);
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
            gmm1FirstTensorElementShape.GetDim(TENSOR_HIDDEN_INDEX);
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = true;
    } else {  // Single
        if (elementDims == 2) {  // one localExpert perRank
            OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
                       return ge::GRAPH_FAILED);
            // Without the leading expert dimension the hidden size moves one
            // position to the left.
            tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
                gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX - 1);
        } else {  // multi localExperts perRank
            OPS_ERR_IF(moeExpertNumPerRank != gmm1FirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number per rank."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(1),
                       OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
                       return ge::GRAPH_FAILED);
            tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
                gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX);
        }
        tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = false;
    }
    return ge::GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
// Validates the gmm1 weight-scale input against the gmm1 hidden size
// (`gmm1HLen`, which CheckGmm1Shape stores in the tiling data).
//
// Only the first tensor of the dynamic input is inspected:
//   * tensor list (more than one entry) or single 1-D tensor: dim 0 must equal
//     gmm1HLen;
//   * single 2-D tensor: dim 0 must equal moeExpertNumPerRank and dim 1 must
//     equal gmm1HLen.
//
// Returns GRAPH_SUCCESS when the shape is accepted, otherwise logs the reason
// and returns GRAPH_FAILED.
static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context,
                                           DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;

    uint32_t gmm1ScaleListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_SCALE_INDEX);
    auto gmm1ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_SCALE_INDEX, 0);
    // GetDynamicInputTensor returns nullptr when the entry is absent (an empty
    // list); bail out instead of dereferencing it below.
    OPS_ERR_IF(gmm1ScaleFirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm1WeightScale tensor is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm1ScaleFirstTensorElementShape = gmm1ScaleFirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm1ScaleFirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm1WeightScale shape is invalid."),
               return ge::GRAPH_FAILED);
    if (gmm1ScaleListLen > 1) {  // List
        // NOTE(review): the list length itself is not compared with
        // moeExpertNumPerRank here -- confirm whether that is intentional.
        OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
                   OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."),
                   return ge::GRAPH_FAILED);
    } else {  // Single
        if (elementDims == 1) {  // one localExpert perRank
            OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."),
                       return ge::GRAPH_FAILED);
        } else {  // multi localExperts perRank
            OPS_ERR_IF(moeExpertNumPerRank != gmm1ScaleFirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm1Scale does not match local expert number perRank."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(1),
                       OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."),
                       return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
// Validates the gmm2 weight input against the token hidden size `h` and the
// gmm1 hidden size `gmm1HLen` (both read from the tiling data).
//
// Only the first tensor of the dynamic input is inspected. The expected
// per-expert shape is [gmm1HLen / 2, h]:
//   * tensor list (more than one entry): list length must equal
//     moeExpertNumPerRank; element dims 0/1 are checked;
//   * single 2-D tensor: dims 0/1 are checked;
//   * single 3-D tensor: dim 0 must equal moeExpertNumPerRank; dims 1/2 are
//     checked.
//
// Returns GRAPH_SUCCESS when the shape is accepted, otherwise logs the reason
// and returns GRAPH_FAILED.
static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
    uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;

    uint32_t gmm2ListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_INDEX);
    auto gmm2FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_INDEX, 0);
    // GetDynamicInputTensor returns nullptr when the entry is absent (an empty
    // list); bail out instead of dereferencing it below.
    OPS_ERR_IF(gmm2FirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm2Weight tensor is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm2FirstTensorElementShape = gmm2FirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."),
               return ge::GRAPH_FAILED);
    // The log texts below are paired with the quantity actually compared: the
    // n / 2 checks report the "half of gmm1 hidden size" mismatch and the h
    // checks report the "token hidden size" mismatch.
    if (gmm2ListLen > 1) {  // List
        OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank,
                   OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
                   OPS_LOG_E(nodeName, "gmm2 does not match half of gmm1 hidden size."),
                   return ge::GRAPH_FAILED);
        OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
                   OPS_LOG_E(nodeName, "gmm2 does not equals to token hidden size."),
                   return ge::GRAPH_FAILED);
    } else {  // Single
        if (elementDims == 2) {  // one localExpert perRank
            OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
                       OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."),
                       return ge::GRAPH_FAILED);
        } else {  // multi localExperts perRank
            OPS_ERR_IF(moeExpertNumPerRank != gmm2FirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm2Weight does not match local expert num perRank."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(1),
                       OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(2),
                       OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."),
                       return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
// Validates the gmm2 weight-scale input against the token hidden size `h`
// read from the tiling data.
//
// Only the first tensor of the dynamic input is inspected:
//   * tensor list (more than one entry) or single 1-D tensor: dim 0 must equal
//     h;
//   * single 2-D tensor: dim 0 must equal moeExpertNumPerRank and dim 1 must
//     equal h.
//
// Returns GRAPH_SUCCESS when the shape is accepted, otherwise logs the reason
// and returns GRAPH_FAILED.
static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context,
                                           DispatchGmmCombineDecodeTilingData *tilingData)
{
    const char *nodeName = context->GetNodeName();
    uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
    uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;

    uint32_t gmm2ScaleListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_SCALE_INDEX);
    auto gmm2ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_SCALE_INDEX, 0);
    // GetDynamicInputTensor returns nullptr when the entry is absent (an empty
    // list); bail out instead of dereferencing it below.
    OPS_ERR_IF(gmm2ScaleFirstTensorElement == nullptr, OPS_LOG_E(nodeName, "gmm2WeightScale tensor is null."),
               return ge::GRAPH_FAILED);
    const auto &gmm2ScaleFirstTensorElementShape = gmm2ScaleFirstTensorElement->GetOriginShape();
    uint32_t elementDims = gmm2ScaleFirstTensorElementShape.GetDimNum();
    OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm2WeightScale shape is invalid."),
               return ge::GRAPH_FAILED);
    if (gmm2ScaleListLen > 1) {  // List
        // NOTE(review): the list length itself is not compared with
        // moeExpertNumPerRank here -- confirm whether that is intentional.
        OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
                   OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
    } else {  // Single
        if (elementDims == 1) {  // one localExpert perRank
            OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."),
                       return ge::GRAPH_FAILED);
        } else {  // multi localExperts perRank
            OPS_ERR_IF(moeExpertNumPerRank != gmm2ScaleFirstTensorElementShape.GetDim(0),
                       OPS_LOG_E(nodeName, "gmm2Scale does not match local expert number perRank."),
                       return ge::GRAPH_FAILED);
            OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(1),
                       OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."),
                       return ge::GRAPH_FAILED);
        }
    }
    return ge::GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
// Runs the four weight / weight-scale shape checks in a fixed order and stops
// at the first one that does not return GRAPH_SUCCESS.
//
// The order matters: CheckGmm1Shape stores gmm1HLen in the tiling data, and
// the checks after it read that value.
static ge::graphStatus CheckWeightTensorList(gert::TilingContext *context,
                                             DispatchGmmCombineDecodeTilingData *tilingData)
{
    using ShapeChecker = ge::graphStatus (*)(gert::TilingContext *, DispatchGmmCombineDecodeTilingData *);
    const ShapeChecker orderedCheckers[] = {
        CheckGmm1Shape,
        CheckGmm1ScaleShape,
        CheckGmm2Shape,
        CheckGmm2ScaleShape,
    };
    for (ShapeChecker runCheck : orderedCheckers) {
        if (runCheck(context, tilingData) != ge::GRAPH_SUCCESS) {
            return ge::GRAPH_FAILED;
        }
    }
    return ge::GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
ge::graphStatus CheckXActiveMaskShape(gert::TilingContext *context, const char *nodeName,
|
|
|
|
|
DispatchGmmCombineDecodeTilingData &tilingData)
|
|
|
|
|
{
|
|
|
|
|
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
|
|
|
|
|
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
|
|
|
|
|
@@ -76,55 +247,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
|
|
|
|
|
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
|
|
|
|
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
|
|
|
|
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
|
|
|
|
|
|
|
|
|
uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
|
|
|
|
|
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
|
|
|
|
|
const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
|
|
|
|
|
gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
|
|
|
|
|
const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
|
|
|
|
|
const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
|
|
|
|
|
gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
|
|
|
|
|
OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
|
|
|
|
|
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
|
|
|
|
|
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
|
|
|
|
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
|
|
|
|
if (xActiveMaskStorageShape != nullptr) {
|
|
|
|
|
@@ -312,21 +435,19 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
|
|
|
|
|
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
|
|
|
|
|
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
|
|
|
|
|
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
|
|
|
|
|
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
|
|
|
|
|
OPS_ERR_IF(CheckWeightTensorList(context, tilingData) != ge::GRAPH_SUCCESS,
|
|
|
|
|
OPS_LOG_E(nodeName, "CheckWeightTensorList failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
|
|
|
|
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
|
|
|
|
|
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
|
|
|
|
|
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
|
|
|
|
|
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
|
|
|
|
|
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
|
|
|
|
|
return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
|
|
|
|
|
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
|
|
|
|
OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(CheckXActiveMaskShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
|
|
|
|
OPS_LOG_E(nodeName, "CheckXActiveMaskShape failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
|
|
|
|
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
|
|
|
|
|
SetHcommCfg(context, tilingData, groupEp);
|
|
|
|
|
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
|
|
|
|
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
|
|
|
|
@@ -338,6 +459,9 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
|
|
|
|
|
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
|
|
|
|
|
tilingKey |= EXEC_FLAG_DEEP_FUSE;
|
|
|
|
|
}
|
|
|
|
|
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) {
|
|
|
|
|
tilingKey |= EXEC_FLAG_TENSOR_LIST;
|
|
|
|
|
}
|
|
|
|
|
context->SetTilingKey(tilingKey);
|
|
|
|
|
context->SetBlockDim(aicNum);
|
|
|
|
|
return ge::GRAPH_SUCCESS;
|
|
|
|
|
|