[Feature] DispatchGmmCombineDecode support bf16/float16 gmm1/gmm2 weight and support gmm weight with ND format (#6393)
### What this PR does / why we need it?
1. support ND format gmm weight input.
Before this pr, gmm1_weight and gmm2_weight could only be passed as
input to the DispatchGmmCombineDecode operator in NZ data format. After
the modification, they are allowed to be passed in ND data format.
2. support bf16/float16 gmm weight
The current PR modification enables the DispatchGmmCombineDecode
operator to support non-W8A8 scenarios, allowing gmm1_weight and
gmm2_weight to be passed as float16/bfloat16 which is correspond with
input token data type.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: lih827 <383084552@qq.com>
This commit is contained in:
@@ -19,6 +19,7 @@ add_ops_compile_options(
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
-DASCENDC_DUMP=0
|
||||
${_DISPATCH_GMM_INC_OPTS}
|
||||
)
|
||||
|
||||
|
||||
@@ -18,93 +18,190 @@ public:
|
||||
this->Input("x")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16,
|
||||
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16,
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_ids")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
|
||||
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
|
||||
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
|
||||
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
|
||||
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("gmm1_permuted_weight")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("gmm1_permuted_weight_scale")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("gmm2_weight")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||
ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("gmm2_weight_scale")
|
||||
.ParamType(DYNAMIC)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16})
|
||||
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16,
|
||||
ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT16,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_scales")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("expert_smooth_scales")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("x_active_mask")
|
||||
.ParamType(OPTIONAL)
|
||||
.DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
|
||||
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
|
||||
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
|
||||
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
|
||||
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
|
||||
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("output")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16,
|
||||
ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16,
|
||||
ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16,
|
||||
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("expert_token_nums")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
|
||||
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
|
||||
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
|
||||
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
|
||||
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Attr("group_ep").String();
|
||||
this->Attr("ep_rank_size").Int();
|
||||
this->Attr("ep_rank_id").Int();
|
||||
|
||||
@@ -91,6 +91,16 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC
|
||||
auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0);
|
||||
auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape();
|
||||
uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum();
|
||||
ge::DataType gmm1DataType = gmm1FirstTensorElement->GetDataType();
|
||||
if (gmm1DataType == ge::DT_BF16 || gmm1DataType == ge::DT_FLOAT16) {
|
||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = true;
|
||||
} else {
|
||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = false;
|
||||
}
|
||||
auto gmm1WeightDesc = context->GetDynamicInputDesc(INPUT_GMM1_WEIGHT_INDEX, 0);
|
||||
if (GetPrimaryFormat(gmm1WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) {
|
||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true;
|
||||
}
|
||||
|
||||
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
@@ -129,6 +139,9 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC
|
||||
static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context,
|
||||
DispatchGmmCombineDecodeTilingData *tilingData)
|
||||
{
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) {
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
const char *nodeName = context->GetNodeName();
|
||||
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
||||
@@ -170,6 +183,10 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC
|
||||
uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum();
|
||||
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."),
|
||||
return ge::GRAPH_FAILED);
|
||||
auto gmm2WeightDesc = context->GetDynamicInputDesc(INPUT_GMM2_WEIGHT_INDEX, 0);
|
||||
if (GetPrimaryFormat(gmm2WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) {
|
||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true;
|
||||
}
|
||||
if (gmm2ListLen > 1) { // List
|
||||
OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank,
|
||||
OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED);
|
||||
@@ -198,6 +215,9 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC
|
||||
static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context,
|
||||
DispatchGmmCombineDecodeTilingData *tilingData)
|
||||
{
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) {
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
const char *nodeName = context->GetNodeName();
|
||||
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||
uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||
@@ -383,16 +403,26 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
|
||||
} else {
|
||||
maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank);
|
||||
}
|
||||
uint32_t wTypeSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? TOKEN_DTYPE_BYTE_SIZE : sizeof(int8_t);
|
||||
|
||||
size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t);
|
||||
size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t);
|
||||
// hbm input = x: float16 or bf16
|
||||
// buf1 dispatch (Only AIV) => x1: float16 or bf16
|
||||
// buf2 gmm1 (Only AIC) => y1: float
|
||||
// sync
|
||||
// buf3 swiglu (Only AIV) => x2: float16 or bf16
|
||||
// sync ?
|
||||
// buf4 gmm2 (AIC & AIV) => y2: float16 or bf16
|
||||
// hbm combine (Only AIV) => output: float16 or bf16
|
||||
|
||||
size_t x1TokenSize = maxTokenNum * h * wTypeSize; // x1: float16 or bf16
|
||||
size_t x2TokenSize = maxTokenNum * gmm2HLen * wTypeSize; // x2: float16 or bf16
|
||||
size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
|
||||
maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE);
|
||||
size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t tokenScaleSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? 0 : CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE);
|
||||
size_t CVSwapBufferSize =
|
||||
CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE);
|
||||
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float);
|
||||
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE;
|
||||
size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); // y1: float
|
||||
size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; // y2: float
|
||||
size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize;
|
||||
maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE);
|
||||
size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE);
|
||||
@@ -462,6 +492,10 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) {
|
||||
tilingKey |= EXEC_FLAG_TENSOR_LIST;
|
||||
}
|
||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat) {
|
||||
tilingKey |= EXEC_FLAG_ND_FORMAT;
|
||||
}
|
||||
|
||||
context->SetTilingKey(tilingKey);
|
||||
context->SetBlockDim(aicNum);
|
||||
return ge::GRAPH_SUCCESS;
|
||||
|
||||
Reference in New Issue
Block a user