[Feature] DispatchGmmCombineDecode support bf16/float16 gmm1/gmm2 weight and support gmm weight with ND format (#6393)

### What this PR does / why we need it?
1. support ND format gmm weight input.
Before this pr, gmm1_weight and gmm2_weight could only be passed as
input to the DispatchGmmCombineDecode operator in NZ data format. After
the modification, they are allowed to be passed in ND data format.
2. support bf16/float16 gmm weight
The current PR modification enables the DispatchGmmCombineDecode
operator to support non-W8A8 scenarios, allowing gmm1_weight and
gmm2_weight to be passed as float16/bfloat16 which is correspond with
input token data type.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: lih827 <383084552@qq.com>
This commit is contained in:
lih827
2026-02-12 10:37:41 +08:00
committed by GitHub
parent f1ffb5fb19
commit f71812011d
18 changed files with 3766 additions and 237 deletions

View File

@@ -19,6 +19,7 @@ add_ops_compile_options(
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
-DASCENDC_DUMP=0
${_DISPATCH_GMM_INC_OPTS}
)

View File

@@ -18,93 +18,190 @@ public:
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16,
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat(
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight_scale")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16})
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat(
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight_scale")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16,
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16})
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16,
ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT16,
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("x_active_mask")
.ParamType(OPTIONAL)
.DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16,
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16,
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("expert_token_nums")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();

View File

@@ -91,6 +91,16 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC
auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0);
auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape();
uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum();
ge::DataType gmm1DataType = gmm1FirstTensorElement->GetDataType();
if (gmm1DataType == ge::DT_BF16 || gmm1DataType == ge::DT_FLOAT16) {
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = true;
} else {
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = false;
}
auto gmm1WeightDesc = context->GetDynamicInputDesc(INPUT_GMM1_WEIGHT_INDEX, 0);
if (GetPrimaryFormat(gmm1WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) {
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true;
}
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."),
return ge::GRAPH_FAILED);
@@ -129,6 +139,9 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC
static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context,
DispatchGmmCombineDecodeTilingData *tilingData)
{
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) {
return ge::GRAPH_SUCCESS;
}
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
@@ -170,6 +183,10 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC
uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum();
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."),
return ge::GRAPH_FAILED);
auto gmm2WeightDesc = context->GetDynamicInputDesc(INPUT_GMM2_WEIGHT_INDEX, 0);
if (GetPrimaryFormat(gmm2WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) {
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true;
}
if (gmm2ListLen > 1) { // List
OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank,
OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED);
@@ -198,6 +215,9 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC
static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context,
DispatchGmmCombineDecodeTilingData *tilingData)
{
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) {
return ge::GRAPH_SUCCESS;
}
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
@@ -383,16 +403,26 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
} else {
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
}
uint32_t wTypeSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? TOKEN_DTYPE_BYTE_SIZE : sizeof(int8_t);
size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
// hbm input = x: float16 or bf16
// buf1 dispatch (Only AIV) => x1: float16 or bf16
// buf2 gmm1 (Only AIC) => y1: float
// sync
// buf3 swiglu (Only AIV) => x2: float16 or bf16
// sync ?
// buf4 gmm2 (AIC & AIV) => y2: float16 or bf16
// hbm combine (Only AIV) => output: float16 or bf16
size_t x1TokenSize = maxTokenNum * h * wTypeSize; // x1: float16 or bf16
size_t x2TokenSize = maxTokenNum * gmm2HLen * wTypeSize; // x2: float16 or bf16
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t tokenScaleSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? 0 : CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
size_t CVSwapBufferSize =
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); // y1: float
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; // y2: float
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
@@ -462,6 +492,10 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) {
tilingKey |= EXEC_FLAG_TENSOR_LIST;
}
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat) {
tilingKey |= EXEC_FLAG_ND_FORMAT;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;