diff --git a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt index 1c7abb24..7039b61f 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt +++ b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt @@ -19,6 +19,7 @@ add_ops_compile_options( OPTIONS --cce-auto-sync=off -Wno-deprecated-declarations -Werror + -DASCENDC_DUMP=0 ${_DISPATCH_GMM_INC_OPTS} ) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 1f991815..838b05b8 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -18,93 +18,190 @@ public: this->Input("x") .ParamType(REQUIRED) .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_ids") .ParamType(REQUIRED) .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, - ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, - ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + 
ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat( - {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight_scale") .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight") .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, - ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat( - {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, 
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight_scale") .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16, - ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16}) + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT16, + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_scales") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_smooth_scales") .ParamType(OPTIONAL) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) 
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x_active_mask") .ParamType(OPTIONAL) .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, - ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("output") .ParamType(REQUIRED) .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("expert_token_nums") .ParamType(REQUIRED) .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, - ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) 
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); this->Attr("ep_rank_size").Int(); this->Attr("ep_rank_id").Int(); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index f4d3f430..847a8381 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -91,6 +91,16 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0); auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape(); uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum(); + ge::DataType gmm1DataType = gmm1FirstTensorElement->GetDataType(); + if (gmm1DataType == ge::DT_BF16 || gmm1DataType == ge::DT_FLOAT16) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = true; + } else { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = false; + } + auto gmm1WeightDesc = context->GetDynamicInputDesc(INPUT_GMM1_WEIGHT_INDEX, 0); + if (GetPrimaryFormat(gmm1WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true; + } OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."), return ge::GRAPH_FAILED); @@ -129,6 +139,9 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) { + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) { + return ge::GRAPH_SUCCESS; + } const char *nodeName = context->GetNodeName(); uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; @@ -170,6 +183,10 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum(); OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."), return ge::GRAPH_FAILED); + auto gmm2WeightDesc = context->GetDynamicInputDesc(INPUT_GMM2_WEIGHT_INDEX, 0); + if (GetPrimaryFormat(gmm2WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true; + } if (gmm2ListLen > 1) { // List OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank, OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED); @@ -198,6 +215,9 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) { + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) { + return ge::GRAPH_SUCCESS; + } const char *nodeName = context->GetNodeName(); uint32_t moeExpertNumPerRank = 
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; @@ -383,16 +403,26 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no } else { maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); } + uint32_t wTypeSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? TOKEN_DTYPE_BYTE_SIZE : sizeof(int8_t); - size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t); - size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t); + // hbm input = x: float16 or bf16 + // buf1 dispatch (Only AIV) => x1: float16 or bf16 + // buf2 gmm1 (Only AIC) => y1: float + // sync + // buf3 swiglu (Only AIV) => x2: float16 or bf16 + // sync ? + // buf4 gmm2 (AIC & AIV) => y2: float16 or bf16 + // hbm combine (Only AIV) => output: float16 or bf16 + + size_t x1TokenSize = maxTokenNum * h * wTypeSize; // x1: float16 or bf16 + size_t x2TokenSize = maxTokenNum * gmm2HLen * wTypeSize; // x2: float16 or bf16 size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize; maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE); - size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t tokenScaleSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? 0 : CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); size_t CVSwapBufferSize = CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE); - size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); - size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; + size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); // y1: float + size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; // y2: float16 or bf16 size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize; maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE); size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE); @@ -462,6 +492,10 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) { tilingKey |= EXEC_FLAG_TENSOR_LIST; } + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat) { + tilingKey |= EXEC_FLAG_ND_FORMAT; + } + context->SetTilingKey(tilingKey); context->SetBlockDim(aicNum); return ge::GRAPH_SUCCESS; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index aae344af..3fa6575c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -8,6 +8,7 @@ * See LICENSE in the root of the software repository for the full text of the License.
*/ #include "dispatch_gmm_combine_decode.h" +#include "dispatch_gmm_combine_decode_bf16_fp16.h" #include #include "lib/matmul_intf.h" @@ -25,12 +26,28 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); + +#if (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_INT8) if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || - TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { - DispatchGmmCombineDecode< - DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op; + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) || + TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) || + TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) { + DispatchGmmCombineDecodeImpl::DispatchGmmCombineDecode< + DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int8_t, int32_t, false, TILING_KEY_VAR> op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); } +#elif (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_BF16 || ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_FLOAT16) + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) || + TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) || + TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) { + DispatchGmmCombineDecodeBf16Fp16Impl::DispatchGmmCombineDecodeBf16Fp16< + DTYPE_GMM1_PERMUTED_WEIGHT, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, DTYPE_GMM1_PERMUTED_WEIGHT, int32_t, false, TILING_KEY_VAR> op; + op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, + expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); + op.Process(); + } +#endif } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index 97aa44ea..264483c5 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -34,7 +34,7 @@ #include "dispatch_gmm_combine_decode_base.h" using namespace Catlass; - +namespace DispatchGmmCombineDecodeImpl { using MmadAtlasA2Custom = Gemm::MmadAtlasA2PreloadAsyncWithCallback CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, - layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, @@ -73,7 +75,8 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, 
uint32_t groupCoun using L0TileShape = L0TileShape_; using AType = Gemm::GemmType; - using BType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; using CType = Gemm::GemmType; using BlockMmad = Gemm::Block::BlockMmad; @@ -107,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using ElementGroupList = int64_t; using GemmKernel = typename std::conditional< - (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< @@ -178,7 +181,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun template CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, - layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmWorkspace, void *combiner) @@ -189,7 +194,8 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR using L0TileShape = L0TileShape_; using AType = Gemm::GemmType; - using BType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; using CType = Gemm::GemmType; using BlockMmad = Gemm::Block::BlockMmad; @@ -342,6 +348,16 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( gmm2InputDim_ = gmm1OutputDim_ / 2; } +template +__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) { + if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) { + MatrixCoord mc{k, n}; + return layout::RowMajor::template MakeLayoutInUb(mc); + } else { + return layout::zN::template MakeLayout(k, n); + } +} + template __aicore__ inline void DispatchGmmCombineDecode::Process() { @@ -349,11 +365,11 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_}; layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_}; - layout::zN layoutWeight1 = layout::zN::template MakeLayout(tokenHiddenSize_, gmm1OutputDim_); + auto layoutWeight1 = CreateWeightLayout(tokenHiddenSize_, gmm1OutputDim_); layout::VectorLayout layoutW1Scale{gmm1OutputDim_}; layout::VectorLayout layoutX1Scale{maxTokenNum_}; layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_}; - layout::zN layoutWeight2 = layout::zN::template MakeLayout(gmm2InputDim_, gmm2OutputDim_); + auto layoutWeight2 = CreateWeightLayout(gmm2InputDim_, gmm2OutputDim_); layout::VectorLayout layoutW2Scale{gmm2OutputDim_}; layout::VectorLayout layoutX2Scale{maxTokenNum_}; layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_}; @@ -436,4 +452,5 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut, layoutOutput, gmWorkspace, &combiner); } +} // namespace DispatchGmmCombineDecodeImpl #endif // 
DISPATCH_GMM_COMBINE_DECODE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h index 4bbbc792..1d7f6aab 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h @@ -12,3 +12,6 @@ #include "block_epilogue_per_token_dequant_swiglu.h" #include "block_epilogue_per_token_dequant.hpp" + +#include "block_epilogue_swiglu_bf16_fp16.h" +#include "block_epilogue_bf16_fp16.hpp" \ No newline at end of file diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp new file mode 100644 index 00000000..83c8300c --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP +#define ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP + +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2Combine; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementRawScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE void AlignUbOffset() + { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo 
&calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), + 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); + } + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) + { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } + + CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) + { + if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } + } + + CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) + { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); + constexpr uint32_t epRankStart = 0; + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 ? 
0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } + } + + CATLASS_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + 
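+            // Editor's note (added in review, not part of the original patch): the branch a few
+            // lines below decides where the freshly cast tile goes. In the deep-fuse path,
+            // DoCombineSend() walks epSendCountLocal_, which holds per-rank prefix sums of token
+            // counts for the current local expert; with hypothetical counts {3, 3, 7, 9} for four
+            // ranks, tokens [0,3) go to rank 0, none to rank 1, [3,7) to rank 2 and [7,9) to
+            // rank 3, each contiguous range as one DataCopyPad into the owning rank's HCCL
+            // window. Without deep fuse, the tile is simply written back to gmD.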
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + Arch::Resource &resource; + MoeDistributeCombineImpl::CombineCalcInfo calcInfo; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + AscendC::LocalTensor epSendCountLocal_; + + size_t ubOffset{0}; + int32_t eventVMTE2{0}; + int32_t eventMTE2V{0}; + int32_t eventMTE3V{0}; + int32_t eventVMTE3{0}; + int32_t eventVS{0}; + int32_t eventMTE2S{0}; + + uint32_t expertOffset; + + uint32_t ubListId{0}; + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h new file mode 100644 index 00000000..63ad0e66 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#pragma once +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "../tile/tile_stride_muls.h" +#include "../tile/tile_stride_binary.h" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2Swiglu; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0, + "The per token scale granularity for word calculation must be 32 bytes aligned."); + static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts."); + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementRawScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + 
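+        // Editor's note (added in review, not part of the original patch): this constructor
+        // carves UB_STAGES double-buffer slots out of the unified buffer — per stage one fp32
+        // input tile (ubCList[i]) and one half-precision output tile (ubDList[i]) — plus one
+        // shared denominator tile at the end. Judging by the variable names (the event template
+        // arguments are stripped in this listing), the initial SetFlag calls prime the
+        // V_MTE2 / MTE3_V / MTE3_MTE2 event pairs so every slot starts out "free" and the first
+        // WaitFlag in operator() does not deadlock.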
size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + int32_t eventMTE3MTE2 = 0; + int32_t eventMTE2MTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + eventUbMTE3MTE2List[i] = eventMTE3MTE2++; + eventUbMTE2MTE3List[i] = eventMTE2MTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + AscendC::SetFlag(eventUbMTE3MTE2List[i]); + } + ubDenominatorMxN = resource.ubBuf.template GetBufferByByte(ubOffset); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + AscendC::WaitFlag(eventUbMTE3MTE2List[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (0 == actualBlockShapeMNK.k()) { + return; + } + callback(); + ubListId = 0; + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1); + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = 0; // for 1C1V + uint32_t subblockNum = 1; // for 1C1V + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + if (isLeft) { + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Muls(ubDenominatorMxN, ubC, -1.0f, TileShape::COUNT); + AscendC::PipeBarrier(); 
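+                // Editor's note (added in review, not part of the original patch): together with
+                // the Muls above, the Exp/Adds/Div chain below evaluates, per element,
+                //     denom = 1 + exp(-x),   ubD = x / denom = x * sigmoid(x) = SiLU(x),
+                // i.e. the left half of the SwiGLU input gets the SiLU activation, while the
+                // else-branch further down copies the right half through unchanged; the
+                // PipeBarriers serialize the dependent vector ops within the V pipe.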
+ AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Div(ubD, ubC, ubDenominatorMxN, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + } else { + AscendC::WaitFlag(eventUbMTE3MTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbMTE2MTE3List[ubListId]); + AscendC::WaitFlag(eventUbMTE2MTE3List[ubListId]); + copyUbToGmD(gmTileD, ubC, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbMTE3MTE2List[ubListId]); + } + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + int32_t eventUbMTE3MTE2List[UB_STAGES]; + int32_t eventUbMTE2MTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubDenominatorMxN; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h index df70c101..567ea210 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h @@ -26,4 +26,17 @@ struct EpilogueAtlasA2PerTokenDequantCombine { static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; }; +template +struct EpilogueAtlasA2Swiglu { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +template +struct EpilogueAtlasA2Combine { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; } // namespace Catlass::Epilogue diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h new file mode 100644 index 00000000..ddb26352 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP +#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP + +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementRawScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, + void *combiner_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + 
layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), + combiner(combiner_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
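+                // step back stageUsed slots, modulo WORKSPACE_STAGES, to reach the oldest stage still awaiting its AIV flag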
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + } + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
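+                // same round-robin wrap as on the AIC side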
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, + layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMMultiStageWorkspace; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMMultiStageWorkspace; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 967e5869..b7fb7623 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -22,51 +22,6 @@ #include 
"../../../dispatch_gmm_combine_decode_base.h" -constexpr uint32_t STATE_OFFSET = 512; -constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; -constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; -constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024; -constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024; -constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; -constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024; -constexpr uint32_t UB_ALIGN = 32; -constexpr uint32_t TOKEN_EXTRA_SPACE = 512; -constexpr uint32_t INT32_COUNT_PER_BLOCK = 8; -constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512; -constexpr int64_t LOOP_TMP_SIZE = 4096; -constexpr int32_t SUB_AIV_NUM = 2; -constexpr int32_t ODD_EVEN_BASE = 2; -constexpr int32_t BUFFER_NUM = 2; -constexpr int32_t GATHER_SECOND_NUM = 2; -constexpr uint32_t MAX_QUANT_ROW_ONCE = 8; -constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant -#define OPT_RANK_OFFSET 512 - -#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN) -#define CEIL(x, y) (((x) + (y - 1)) / (y)) -#define UB_BLOCK_SIZE (32) -#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \ - (((epRankId == rankId) \ - ? ((GM_ADDR)(winContext_->localWindowsExp)) \ - : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \ - dataState * WIN_STATE_OFFSET) -#define GET_WIND_ADDR_BY_RANK_ID(rankId) \ - (((epRankId == rankId) \ - ? ((GM_ADDR)(winContext_->localWindowsIn)) \ - : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ - winDataSizeOffset + rankId * OPT_RANK_OFFSET) -#define TOKEN_FLAG_1 (0x55555555) -#define TOKEN_FLAG_2 (0x33333333) -#define V_TO_C_FLAG_1 (0x03030303) -#define V_TO_C_FLAG_2 (0x05050505) -#define CV_FLAG_INDEX 0 -#define GROUP_ID_INDEX 1 -#define PRE_COUNT_INDEX 2 -#define SELF_COUNT_INDEX 3 -#define TOTAL_COUNT_INDEX 4 -#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX -#define GROUP_INFO_SIZE 32 - namespace Catlass::Gemm::Kernel { template @@ -306,54 +261,6 @@ private: Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; }; -__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) -{ - // flag++, like set flag - AscendC::PipeBarrier(); - AscendC::GlobalTensor global; - global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid( - global); - __asm__ __volatile__(""); - uint8_t value = global.GetValue(0); - global.SetValue(0, value + 1); - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid( - global); - __asm__ __volatile__(""); - AscendC::PipeBarrier(); -} - -__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) -{ - // check flag, like wait flag - AscendC::PipeBarrier(); - AscendC::GlobalTensor global; - global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); - while (true) { - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid(global); - __asm__ __volatile__(""); - uint8_t value = global.GetValue(0); - if (value >= target) { - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid(global); - __asm__ __volatile__(""); - break; - } - } - AscendC::PipeBarrier(); -} - -__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) -{ - row = QUANT_SPACE_FACTOR / column; - row = row < MAX_QUANT_ROW_ONCE ? 
row : MAX_QUANT_ROW_ONCE; -} - template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h new file mode 100644 index 00000000..72702d5a --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h @@ -0,0 +1,1805 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once + +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +#include "../../../dispatch_gmm_combine_decode_base.h" + +namespace Catlass::Gemm::Kernel { + +template +class SwigluPost +{ +public: + using ElementInput = float; + using LayoutInput = layout::RowMajor; + using ElementSwigluScale = float; + using LayoutSwigluScale = layout::VectorLayout; + using ElementOutput = ElementOutput_; + using LayoutOutput = layout::RowMajor; + + using InputType = GemmType; + using OutputType = GemmType; + + using EpilogueTileSwizzle = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + struct Params { + __gm__ ElementInput *ptrInput{nullptr}; + LayoutInput layoutInput; + __gm__ ElementSwigluScale *ptrSwigluScale{nullptr}; + LayoutSwigluScale layoutSwigluScale; + __gm__ ElementOutput *ptrOutput{nullptr}; + LayoutOutput layoutOutput; + uint32_t tileRow; + uint32_t tileColumn; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementInput *ptrInput_, LayoutInput const &layoutInput_, + __gm__ ElementSwigluScale *ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, + __gm__ ElementOutput *ptrOutput_, LayoutOutput const layoutOutput_, const uint32_t tileRow_, + const uint32_t tileColumn_) + : ptrInput(ptrInput_), + layoutInput(layoutInput_), + ptrSwigluScale(ptrSwigluScale_), + layoutSwigluScale(layoutSwigluScale_), + ptrOutput(ptrOutput_), + layoutOutput(layoutOutput_), + tileRow(tileRow_), + tileColumn(tileColumn_) + {} + }; + + CATLASS_DEVICE + SwigluPost(Arch::Resource const &resource, Params const ¶ms_) : params(params_) + { + int64_t ubOffset = 0; + tileRow = params_.tileRow; + tileColumn = params_.tileColumn; + tileCount = tileRow * tileColumn; + halfTileColumn = tileColumn / 2; + halfTileCount = tileRow * halfTileColumn; + + ubInput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementInput); + ubOutput = 
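+            // place the output tile in UB immediately after the input tile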
resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementOutput); + + ubInputRightHalf = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(float); + + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + + CATLASS_DEVICE + ~SwigluPost() + { + AscendC::WaitFlag(0); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void operator()(MatrixCoord const &blockShape, MatrixCoord const &blockCoord, MatrixCoord const &actualBlockShape) + { + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmInput; + gmInput.SetGlobalBuffer(params.ptrInput); + AscendC::GlobalTensor gmOutput; + gmOutput.SetGlobalBuffer(params.ptrOutput); + + auto ubTileStride = MakeCoord(static_cast(tileColumn), 1L); + auto ubHalfTileStride = MakeCoord(static_cast(halfTileColumn), 1L); + auto tileShape = MakeCoord(tileRow, tileColumn); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileInput = gmInput[params.layoutInput.GetOffset(tileOffset)]; + auto layoutGmTileInput = params.layoutInput.GetTileLayout(actualTileShape); + + layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(0); + // continue swiglu computing here + copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); + copyGmToUbInput(ubInputRightHalf, gmTileInput[params.layoutInput.shape(1) >> 1], layoutUbInput, layoutGmTileInput); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + AscendC::Mul(ubInput, ubInput, ubInputRightHalf, tileCount); + AscendC::PipeBarrier(); + AscendC::WaitFlag(1); + AscendC::Cast(ubOutput, ubInput, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::SetFlag(1); + AscendC::SetFlag(0); + + auto gmTileOutput = gmOutput[params.layoutOutput.GetOffset(tileOffset)]; + auto layoutGmTileOutput = params.layoutOutput.GetTileLayout(actualTileShape); + + LayoutOutput layoutUbOutput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(1); + copyUbToGmOutput(gmTileOutput, ubOutput, layoutGmTileOutput, layoutUbOutput); + AscendC::SetFlag(1); + } + } + +private: + Params params; + uint32_t tileRow; + uint32_t tileColumn; + uint32_t tileCount; + uint32_t halfTileColumn; + uint32_t halfTileCount; + + AscendC::LocalTensor ubInput; + AscendC::LocalTensor ubOutput; + + AscendC::LocalTensor ubInputRightHalf; + + Epilogue::Tile::CopyGm2Ub copyGmToUbInput; + Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; +}; + +template +class GroupedMatmulSliceMSwigluMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using 
BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementRawScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using XType = ExpandXType; + using ElementSwigluScale = typename SwigluPost::ElementSwigluScale; + using LayoutSwigluScale = typename SwigluPost::LayoutSwigluScale; + using ElementOutput = typename SwigluPost::ElementOutput; + using LayoutOutput = typename SwigluPost::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + + // Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementSwigluScale *ptrSwigluScale; + LayoutSwigluScale layoutSwigluScale; + GM_ADDR ptrWorkspace; + GM_ADDR gmX; + GM_ADDR debugGm; + GM_ADDR gmexpertIds; + GM_ADDR gmXActiveMask; + + GM_ADDR gmExpandIdx; + GM_ADDR gmEpSendCount; + GM_ADDR gmResvered; + GM_ADDR gmExpertTokenNums; + + uint32_t epRankSize; + uint32_t epRankId; + uint32_t moeExpertNum; + uint32_t moeExpertNumPerRank; + uint32_t sharedExpertNum; + uint32_t sharedExpertRankNum; + uint32_t quantMode; + uint32_t globalBs; + uint32_t bs; + uint32_t topK; + uint32_t tokenLen; + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, GM_ADDR ptrWorkspace_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_, + GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_, + uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, + uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, + uint32_t h) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrSwigluScale(reinterpret_cast<__gm__ ElementSwigluScale *>(ptrSwigluScale_)), + 
layoutSwigluScale(layoutSwigluScale_), + ptrWorkspace(ptrWorkspace_), + gmX(gmX_), + debugGm(debugGm_), + gmexpertIds(gmexpertIds_), + gmExpandIdx(gmExpandIdx_), + gmEpSendCount(gmEpSendCount_), + gmExpertTokenNums(gmExpertTokenNums_), + gmXActiveMask(gmXActiveMask_), + gmResvered(gmResvered_), + epRankSize(epRankSize_), + epRankId(epRankId_), + moeExpertNum(moeExpertNum_), + moeExpertNumPerRank(moeExpertNumPerRank_), + sharedExpertNum(sharedExpertNum_), + sharedExpertRankNum(sharedExpertRankNum_), + quantMode(quantMode_), + globalBs(globalBs_), + bs(bs_), + topK(topK_), + tokenLen(h) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMSwigluMultiStageWorkspace() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + aicIdx = AscendC::GetBlockIdx(); + subBlockNum = AscendC::GetSubBlockNum(); + aiCoreGroupNum = AscendC::GetBlockNum(); + aicNum = aiCoreGroupNum; + aivNum = aiCoreGroupNum * SUB_AIV_NUM; + aicStateGlobalCoreIdx = aivNum + aicIdx; + moeExpertNumPerRank = params.moeExpertNumPerRank; + isShareExpert = (params.epRankId < params.sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + // when localExpertNum=1, all cores send token and recv token in sequence + recvCoreNum = aivNum; + // when localExpertNum>1, half of cores send token and another half recv token in parallel + if (localExpertNum > 1) { + recvCoreNum = aiCoreGroupNum; + } + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + // state of cv flag + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aicStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * aicNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + AscendC::GlobalTensor groupTokenNumStateTensor; + aicSetFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part + uint32_t target = 1; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ 
ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + // wait AIV recv needed tokens + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((aicIdx < startCoreIdx) ? (aicIdx + aicNum) : aicIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + aicWaitFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(AscendC::GetBlockIdx()), + target}; // AIC wait for flags in former part + target += 1; + callbackBeforeFixpipe = MakeCallback(&aicWaitFunc1); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFunc1); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * aicNum + aicIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + + startCoreIdx = (startCoreIdx + coreLoops) % aicNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
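+                // modular step-back to the stage being drained; only the soft-sync target advances here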
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + target += 1; + --stageUsed; + } + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset) + { + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor maskInputTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + AscendC::LocalTensor maskInputInt8Tensor = maskInputTensor.template ReinterpretCast(); + subUbOffset += CEIL_UP(axisBS * sizeof(bool)); + AscendC::LocalTensor maskTmpTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(axisBS * sizeof(half)); + AscendC::LocalTensor sumOutTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + + AscendC::GlobalTensor xActiveMaskGMTensor; + xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask); + uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool)); + + AscendC::DataCopyExtParams maskParams = {1U, static_cast(axisBS * sizeof(bool)), 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams maskCopyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS); + AscendC::PipeBarrier(); + AscendC::SumParams params{1, axisBsAlignSize, axisBS}; + AscendC::Sum(sumOutTensor, maskTmpTensor, params); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + activeMaskBsCnt = static_cast(sumOutTensor.GetValue(0)); + } + + CATLASS_DEVICE + void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) + { + // calculate index in remote + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor dstExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor subExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor workLocalTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::Duplicate(dstExpIdTensor_, dstExpertId, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, tokenIndex); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpFp32 = subExpIdTensor_.ReinterpretCast(); + AscendC::LocalTensor tmpoutFp32 = dstExpIdTensor_.ReinterpretCast(); + AscendC::Abs(tmpoutFp32, tmpFp32, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Mins(subExpIdTensor_, dstExpIdTensor_, 1, tokenIndex); + AscendC::PipeBarrier(); + AscendC::ReduceSum(tmpoutFp32, tmpFp32, workLocalTensor_, tokenIndex); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + int32_t curOtherExpertCnt = dstExpIdTensor_(0); + if (tokenIndex > curOtherExpertCnt) { + curExpertCnt = tokenIndex - curOtherExpertCnt; + } + } + + CATLASS_DEVICE + void CalAndSendTokenCount() + { + uint32_t totalExpertNum = sharedExpertRankNum + moeExpertNum; + uint32_t sendCountExpertNum = totalExpertNum / sendCoreNum; + uint32_t remainderRankNum = totalExpertNum % sendCoreNum; + uint32_t startExpertId = sendCountExpertNum * sendCoreIdx; + if (sendCoreIdx < remainderRankNum) { + sendCountExpertNum += 1; + startExpertId += sendCoreIdx; + } else { + startExpertId += remainderRankNum; + } + uint32_t endExpertId = startExpertId + sendCountExpertNum; + if (startExpertId >= totalExpertNum) { + return; + } + + AscendC::LocalTensor statusTensor_ = 
resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(CEIL(expertCntUp, INT32_COUNT_PER_BLOCK) * INT32_COUNT_PER_BLOCK * UB_BLOCK_SIZE); + AscendC::Duplicate(statusTensor_, (int32_t)0, + expertCntUp * INT32_COUNT_PER_BLOCK); + if (state == 0) { + // set the first number of every 8 numbers as 0x3F800000(float 1.0) + uint64_t mask[2] = {0x101010101010101, 0}; + AscendC::PipeBarrier(); + AscendC::Duplicate(statusTensor_, 0x3F800000, mask, CEIL(expertCntUp, 8), 1, 8); + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + if (!isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + for (uint32_t curExpertId = startExpertId; curExpertId < endExpertId; ++curExpertId) { + if (curExpertId < sharedExpertRankNum) { + continue; + } + int32_t curExpertCnt = 0; + int32_t dstExpertId = curExpertId - sharedExpertRankNum; + CalExpandxIdx(dstExpertId, expertIdsCnt, curExpertCnt, ubOffset); + int32_t cntPosIndex = curExpertId * INT32_COUNT_PER_BLOCK + 1; + statusTensor_(cntPosIndex) = curExpertCnt; + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::GlobalTensor rankGMTensor; + uint32_t offset = stateOffset * epRankId; + for (uint32_t rankIndex = startExpertId; rankIndex < endExpertId; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank > 1 && (rankIndex >= sharedExpertRankNum)) { + dstRankId = ((rankIndex - sharedExpertRankNum) / moeExpertNumPerRank + sharedExpertRankNum); + offset = + (epRankId + (rankIndex - sharedExpertRankNum) % moeExpertNumPerRank * epRankSize) * stateOffset; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_STATE_ADDR_BY_RANK_ID(dstRankId) + offset); + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + AscendC::DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); + } + } + + CATLASS_DEVICE + void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) + { + uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; + uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum; + uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum; + uint32_t startTokenId = sendTokenNum * newAivId; + if (newAivId < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + if (startTokenId >= activeMaskBsCnt) { + return; + } + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor xInt32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[0] = xInTensor[0].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[1] = xInTensor[1].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + AscendC::GlobalTensor dstWinGMTensor; + AscendC::GlobalTensor expandXOutGlobal; + expandXOutGlobal.SetGlobalBuffer((__gm__ XType *)(gmX1)); + + // double buffer + AscendC::SetFlag(0); + AscendC::SetFlag(1); + + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t index = (tokenIndex & 
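+            // ping-pong between the two UB staging buffers by token parity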
1) ? 0 : 1;
+            int32_t eventId = (tokenIndex & 1) ? 0 : 1;
+            uint32_t temp = (epRankId * axisBS) / sharedExpertRankNum;
+            uint32_t moeOnShareRank = CEIL((tokenIndex + 1 + temp) * sharedExpertRankNum, axisBS) - 1 - epRankId;
+            uint32_t preCnt = (moeOnShareRank + epRankId) * axisBS / sharedExpertRankNum -
+                              epRankId * axisBS / sharedExpertRankNum;
+            dstWinGMTensor.SetGlobalBuffer(
+                (__gm__ XType *)(GET_WIND_ADDR_BY_RANK_ID(moeOnShareRank) + expertPerSizeOnWin * epRankId));
+
+            AscendC::WaitFlag(eventId);
+            AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex * tokenLength], tokenLength);
+            AscendC::SetFlag(eventId);
+            AscendC::WaitFlag(eventId);
+            xInt32Tensor[index](hOutSize / sizeof(int32_t)) = tokenFlag;
+            AscendC::SetFlag(eventId);
+            AscendC::WaitFlag(eventId);
+
+            if (isShareExpert) {
+                AscendC::DataCopy(expandXOutGlobal[tokenIndex * tokenLength], xInTensor[index], tokenLength);
+            } else {
+                AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommuBf16Fp16], xInTensor[index],
+                                  tokenLength);
+                AscendC::PipeBarrier();
+                AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommuBf16Fp16 + tokenLength],
+                                  xInTensor[index][hOutSize / sizeof(XType)], 16);
+            }
+            AscendC::SetFlag(eventId);
+        }
+        AscendC::WaitFlag(0);
+        AscendC::WaitFlag(1);
+    }
+
+    CATLASS_DEVICE
+    void SendToMoeExprt(GM_ADDR gmX, GM_ADDR gmExpandIdx)
+    {
+        uint32_t sendTokenNum = expertIdsCnt / sendToMoeAivNum;
+        uint32_t remainderTokenNum = expertIdsCnt % sendToMoeAivNum;
+        uint32_t startTokenId = sendTokenNum * sendCoreIdx;
+        if (sendCoreIdx < remainderTokenNum) {
+            sendTokenNum += 1;
+            startTokenId += sendCoreIdx;
+        } else {
+            startTokenId += remainderTokenNum;
+        }
+        uint32_t endTokenId = startTokenId + sendTokenNum;
+        if (startTokenId >= expertIdsCnt) {
+            return;
+        }
+        AscendC::LocalTensor expertCountTensor = (resource.ubBuf.template GetBufferByByte(ubOffset));
+        ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t));
+        AscendC::Duplicate(expertCountTensor, (int32_t)0, expertIdsCnt);
+        AscendC::SetFlag(1);
+        AscendC::WaitFlag(1);
+
+        AscendC::LocalTensor xInTensor[BUFFER_NUM];
+        AscendC::LocalTensor xInt32Tensor[BUFFER_NUM];
+
+        AscendC::GlobalTensor srcWinGMTensor;
+        srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX);
+
+        xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset);
+        xInt32Tensor[0] = xInTensor[0].template ReinterpretCast();
+        ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType));
+        xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset);
+        xInt32Tensor[1] = xInTensor[1].template ReinterpretCast();
+        ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType));
+        AscendC::GlobalTensor dstWinGMTensor;
+        AscendC::SetFlag(0);
+        AscendC::SetFlag(1);
+        uint32_t sendValidTokenIndex = 0;
+        for (uint32_t sendGroupIndex = 0; sendGroupIndex < moeExpertNumPerRank; ++sendGroupIndex) {
+            for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
+                int32_t dstExpertId = expertIdsTensor_(tokenIndex);
+                if (dstExpertId < 0) {
+                    continue;
+                }
+                // Send preferentially to the specified expert group
+                if ((dstExpertId % moeExpertNumPerRank) != sendGroupIndex) {
+                    continue;
+                }
+                uint32_t index = (sendValidTokenIndex & 1) ? 0 : 1;
+                int32_t eventId = (sendValidTokenIndex & 1) ? 
0 : 1; + sendValidTokenIndex += 1; + int32_t curExpertCnt = 0; + CalExpandxIdx(dstExpertId, tokenIndex, curExpertCnt, ubOffset); + expertCountTensor(tokenIndex - startTokenId) = curExpertCnt; + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank + sharedExpertRankNum; + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(tempRankId) + + (expertPerSizeOnWin * (epRankId * moeExpertNumPerRank + + dstExpertId % moeExpertNumPerRank)) + + hCommuSize * curExpertCnt); + dstWinGMTensor.SetGlobalBuffer((__gm__ XType *)rankGM); + + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex / axisK * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + xInt32Tensor[index](hOutSize / sizeof(int32_t)) = tokenFlag; + AscendC::SetFlag(eventId); + + AscendC::WaitFlag(eventId); + + AscendC::DataCopy(dstWinGMTensor, xInTensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[tokenLength], xInTensor[index][hOutSize / sizeof(XType)], 16); + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + + AscendC::GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)gmExpandIdx + startTokenId); + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, + 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(expandIdxGMTensor, expertCountTensor, expertIdsCntParams); + } + + CATLASS_DEVICE void + SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask) + { + ubOffset = 0; + if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) { + TokenActiveMaskCal(gmXActiveMask, ubOffset); + } + expertIdsCnt = activeMaskBsCnt * axisK; + + AscendC::GlobalTensor expertIdsGMTensor_; + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); + expertIdsTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + CalAndSendTokenCount(); + AscendC::PipeBarrier(); + if (hasShareExpert) { + sendToShareAivNum = sendCoreNum / (axisK + 1); + if (sendToShareAivNum == 0) { + sendToShareAivNum = 1; + } + } + sendToMoeAivNum = sendCoreNum - sendToShareAivNum; + + AscendC::SetDeqScale((half)1.000000e+00f); + if (hasShareExpert && sendCoreIdx >= sendToMoeAivNum) { + SendToShareExprt(gmX, gmX1, gmX1Scale); + } else { + SendToMoeExprt(gmX, gmExpandIdx); + } + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RecvCount(int64_t ubOffset) + { + uint32_t recStatusNumPerCore = isShareExpert ? 
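+        // shared-expert ranks poll one status block per rank; moe ranks poll epRankSize * moeExpertNumPerRank blocks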
epRankSize : expertCntUp;
+        uint32_t startStatusIndex = 0; // every core waits for all token counts
+
+        int64_t subUbOffset = ubOffset;
+        AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset);
+        subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE);
+        AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset));
+        subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+        AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+        subUbOffset += CEIL_UP(expertCntUp * sizeof(float));
+        AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast<float>();
+
+        AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+        subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+        AscendC::LocalTensor sumTmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+        subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
+        gatherTmpTensor.SetValue(0, 1);
+
+        uint32_t mask = 1;
+        uint64_t rsvdCnt = 0;
+        AscendC::SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore};
+        float sumOfFlag = static_cast<float>(-1.0);
+        float minTarget = (sumTarget * recStatusNumPerCore) - (float)0.5;
+        float maxTarget = (sumTarget * recStatusNumPerCore) + (float)0.5;
+        AscendC::DataCopyParams intriParams{static_cast<uint16_t>(recStatusNumPerCore), 1, static_cast<uint16_t>(15),
+                                            0};
+        AscendC::GlobalTensor windowInstatusFp32Tensor_;
+        windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId));
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+
+        uint32_t preRecvTokenCount = 0;
+        while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
+            AscendC::DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset / sizeof(float)],
+                              intriParams);
+            AscendC::SetFlag(0);
+            AscendC::WaitFlag(0);
+            AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask,
+                                {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
+            AscendC::PipeBarrier();
+            AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumTmpTensor, sumParams);
+            AscendC::SetFlag(0);
+            AscendC::WaitFlag(0);
+            sumOfFlag = statusSumOutTensor.GetValue(0);
+        }
+    }
+
+    CATLASS_DEVICE
+    void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset)
+    {
+        // calculate token index in output tensor
+        int64_t subUbOffset = ubOffset;
+        uint32_t recStatusNumPerCore = isShareExpert ? 
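+        // same per-rank status count as RecvCount uses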
epRankSize : expertCntUp; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + if (isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + uint64_t rsvdCnt = 0; + gatherTmpTensor.SetValue(0, GATHER_SECOND_NUM); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + AscendC::PipeBarrier(); + AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, + (startRankId + 1) <= recvExpertNum ? (startRankId + 1) : recvExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + } + + CATLASS_DEVICE + void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t &coreTokenCount, uint32_t startRankId, + uint32_t endRankId, uint32_t recvRankNumPerCore, int64_t ubOffset) + { + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::LocalTensor xTmpTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(axisHCommu * sizeof(XType)); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutCountTensor = (gatherMaskOutTensor.template ReinterpretCast()); + AscendC::GlobalTensor tokGlobal; + AscendC::GlobalTensor tokGlobalInt32; + AscendC::GlobalTensor expandXOutGlobal; + uint32_t beginIdx = 0; + for (uint32_t index = startRankId; index < endRankId; index++) { + uint32_t i = index - startRankId; + if (i > 0) { + gatherMaskOutCountTensor.SetValue( + i, gatherMaskOutCountTensor.GetValue(i - 1) + gatherMaskOutCountTensor.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_COUNT_PER_BLOCK + 1); + coreTokenCount += count; + beginIdx = gatherMaskOutCountTensor.GetValue(i) - count; + if (isShareExpert && index < sharedExpertRankNum) { + beginIdx += count; + continue; + } + uint32_t winOffset = index; + if (!isShareExpert && moeExpertNumPerRank > 1) { + // srcRank: index % epRankSize + // localExpertId: index / epRankSize + // Addr: 
(srcRank * moeExpertNumPerRank + localExpertId) * expertPerSizeOnWin + winOffset = (index % epRankSize) * moeExpertNumPerRank + index / epRankSize; + } + GM_ADDR wAddr = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(epRankId)) + winOffset * expertPerSizeOnWin; + AscendC::SetFlag(0); + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ XType *)(wAddr + j * hCommuSize)); + tokGlobalInt32.SetGlobalBuffer((__gm__ int32_t *)(wAddr + j * hCommuSize + hOutSize)); + expandXOutGlobal.SetGlobalBuffer((__gm__ XType *)(gmX1) + (beginIdx + j) * tokenLength, tokenLength); + + while (true) { + AscendC::DataCopy(tmpLocalTensor, tokGlobalInt32, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + if (tmpLocalTensor.GetValue(0) == tokenFlag) { + tokGlobalInt32.SetValue(0, 0); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(tokGlobalInt32[1]); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::DataCopy(xTmpTensor_, tokGlobal, tokenLength); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(expandXOutGlobal, xTmpTensor_, tokenLength); + AscendC::SetFlag(0); + } + AscendC::WaitFlag(0); + beginIdx += count; + } + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyExtParams dataCopyOutParams = {1U, static_cast(recvRankNumPerCore * sizeof(int32_t)), 0U, + 0U, 0U}; + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::DataCopyPad(sendCountsGlobal[startRankId], gatherMaskOutCountTensor, dataCopyOutParams); + } + + CATLASS_DEVICE + void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount) + { + ubOffset = 0; + RecvCount(ubOffset); + + uint32_t recvExpertNum = isShareExpert ? 
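+        // shared experts receive from every rank; moe experts from every (rank, local expert) pair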
epRankSize : expertCntUp; + uint32_t recvCoreNumPerGroup = recvCoreNum / localExpertNum; + uint32_t recvRankNumPerCore = epRankSize / recvCoreNumPerGroup; + uint32_t remainderRankNum = epRankSize % recvCoreNumPerGroup; + + uint32_t groupId = recvCoreIdx / recvCoreNumPerGroup; + uint32_t recvCoreIdxInGroup = recvCoreIdx % recvCoreNumPerGroup; + uint32_t startRankIdInGroup = recvRankNumPerCore * recvCoreIdxInGroup; + if (recvCoreIdxInGroup < remainderRankNum) { + recvRankNumPerCore += 1; + startRankIdInGroup += recvCoreIdxInGroup; + } else { + startRankIdInGroup += remainderRankNum; + } + uint32_t endRankIdInGroup = startRankIdInGroup + recvRankNumPerCore; + uint32_t startRankId = epRankSize * groupId + startRankIdInGroup; + uint32_t endRankId = epRankSize * groupId + endRankIdInGroup; + + uint32_t coreTokenCount = 0; + + if (startRankId < recvExpertNum) { + // RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info + GetCumSum(startRankId, recvExpertNum, ubOffset); + RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); + } + + // recv finish, inform AIC + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(0); + ubOffset += CEIL_UP(UB_BLOCK_SIZE); + tmpLocalTensor.SetValue(CV_FLAG_INDEX, vToCFlag); + tmpLocalTensor.SetValue(GROUP_ID_INDEX, groupId); + tmpLocalTensor.SetValue(SELF_COUNT_INDEX, coreTokenCount); + AscendC::SetFlag(0); + + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::WaitFlag(0); + AscendC::SetAtomicAdd(); + AscendC::DataCopy(groupTokenNumStateTensor[groupId * GROUP_INFO_SIZE], tmpLocalTensor, INT32_COUNT_PER_BLOCK); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ ElementPerTokenScale *gmTokenScale, + __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale, + LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput) + { + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(gmCVSwapBuff)); + auto layoutC = layout::RowMajor{L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES, L1TileShape::N}; + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t target = 1; + uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + AscendC::GlobalTensor groupTokenNumStateTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + // just like AIC + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + uint32_t currentM = 
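+                // token rows accumulated for this group by the recv cores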
groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, n, k}; + LayoutPerTokenScale layoutPerTokenScale = + wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = layout::RowMajor{currentM, n}; + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + } + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = + ((compCoreIdx < startCoreIdx) ? (compCoreIdx + aiCoreGroupNum) : compCoreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aiCoreGroupNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * aiCoreGroupNum + aiCoreGroupIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + CheckSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(compCoreNum + compCoreIdx), target); + target += 1; + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + EncreaseSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(compCoreIdx)); + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
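+                    // advance the workspace stage ring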
(stageId + 1) : 0; + } + + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * n; + + startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum; + } + } + // clean + AscendC::PipeBarrier(); + AscendC::GlobalTensor softSyncTensor; + softSyncTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + SOFT_SYNC_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(softSyncTensor[compCoreIdx * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], tmpZeroLocalTensor, + INT32_COUNT_PER_BLOCK); + AscendC::DataCopy(softSyncTensor[(compCoreIdx + compCoreNum) * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], + tmpZeroLocalTensor, INT32_COUNT_PER_BLOCK); + } + + CATLASS_DEVICE + void AivInitParams(Params const ¶ms) + { + aiCoreGroupNum = AscendC::GetBlockNum(); + subBlockNum = AscendC::GetSubBlockNum(); // 1C2V + aicNum = aiCoreGroupNum; + aivNum = aiCoreGroupNum * subBlockNum; + aivIdx = AscendC::GetBlockIdx(); + aiCoreGroupIdx = aivIdx / subBlockNum; + aivStateGlobalCoreIdx = aivNum + aicNum + aivIdx; + + isCompCore = (aivIdx % subBlockNum) == 0; + compCoreNum = aiCoreGroupNum; + compCoreIdx = aiCoreGroupIdx; + // when localExpertNum=1, all cores send token and recv token in sequence + isRecvCore = true; + isSendCore = true; + recvCoreIdx = aivIdx; + sendCoreIdx = aivIdx; + sendCoreNum = aivNum; + recvCoreNum = aivNum; + + moeExpertNumPerRank = params.moeExpertNumPerRank; + + epRankSize = params.epRankSize; + epRankId = params.epRankId; + expertCntUp = epRankSize * moeExpertNumPerRank; + sharedExpertRankNum = params.sharedExpertRankNum; + hasShareExpert = (sharedExpertRankNum > 0); + isShareExpert = (epRankId < sharedExpertRankNum); + localExpertNum = isShareExpert ? 
+
+    CATLASS_DEVICE
+    void AivInitParams(Params const &params)
+    {
+        aiCoreGroupNum = AscendC::GetBlockNum();
+        subBlockNum = AscendC::GetSubBlockNum(); // 1C2V
+        aicNum = aiCoreGroupNum;
+        aivNum = aiCoreGroupNum * subBlockNum;
+        aivIdx = AscendC::GetBlockIdx();
+        aiCoreGroupIdx = aivIdx / subBlockNum;
+        aivStateGlobalCoreIdx = aivNum + aicNum + aivIdx;
+
+        isCompCore = (aivIdx % subBlockNum) == 0;
+        compCoreNum = aiCoreGroupNum;
+        compCoreIdx = aiCoreGroupIdx;
+        // when localExpertNum=1, all cores send tokens and receive tokens in sequence
+        isRecvCore = true;
+        isSendCore = true;
+        recvCoreIdx = aivIdx;
+        sendCoreIdx = aivIdx;
+        sendCoreNum = aivNum;
+        recvCoreNum = aivNum;
+
+        moeExpertNumPerRank = params.moeExpertNumPerRank;
+
+        epRankSize = params.epRankSize;
+        epRankId = params.epRankId;
+        expertCntUp = epRankSize * moeExpertNumPerRank;
+        sharedExpertRankNum = params.sharedExpertRankNum;
+        hasShareExpert = (sharedExpertRankNum > 0);
+        isShareExpert = (epRankId < sharedExpertRankNum);
+        localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank;
+        moeExpertNum = params.moeExpertNum;
+        tokenLength = params.tokenLen;
+
+        // when localExpertNum>1, half of the cores send tokens while the other half receive them in parallel
+        if (localExpertNum > 1) {
+            isRecvCore = ((aivIdx % ODD_EVEN_BASE) == 0);
+            isSendCore = ((aivIdx % ODD_EVEN_BASE) == 1);
+            recvCoreIdx = aivIdx / subBlockNum;
+            sendCoreIdx = aivIdx / subBlockNum;
+            sendCoreNum = aiCoreGroupNum;
+            recvCoreNum = aiCoreGroupNum;
+        }
+
+        hOutSize = tokenLength * sizeof(XType);
+        scaleParamPad = TOKEN_EXTRA_SPACE;
+        hCommuSize = hOutSize + scaleParamPad;
+        axisHCommu = hCommuSize / sizeof(int8_t);
+        axisHCommuBf16Fp16 = hCommuSize / sizeof(XType);
+        axisBS = params.bs;
+        activeMaskBsCnt = axisBS;
+        axisK = params.topK;
+        uint32_t maxAxisBs = params.globalBs / epRankSize;
+
+        stateOffset = STATE_OFFSET;
+        expertPerSizeOnWin = maxAxisBs * hCommuSize;
+        winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext();
+        statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp);
+    }
+
+    CATLASS_DEVICE
+    void AivInitState()
+    {
+        // state of data space
+        AscendC::GlobalTensor<int32_t> selfDataStatusTensor;
+        selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfDataStatusTensor[aivIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+        dataState = selfDataStatusTensor(aivIdx * UB_ALIGN);
+        if (dataState == 0) {
+            selfDataStatusTensor(aivIdx * UB_ALIGN) = 1;
+        } else {
+            selfDataStatusTensor(aivIdx * UB_ALIGN) = 0;
+        }
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfDataStatusTensor[aivIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+
+        // state of cv flag
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+        cvDataState = selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN);
+        if (cvDataState == 0) {
+            selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 1;
+            vToCFlag = V_TO_C_FLAG_1;
+        } else {
+            selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 0;
+            vToCFlag = V_TO_C_FLAG_2;
+        }
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+
+        AscendC::PipeBarrier<PIPE_ALL>();
+        winDataSizeOffset = dataState * epRankSize * expertPerSizeOnWin * moeExpertNumPerRank;
+        GM_ADDR statusSpaceGm_ = GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId);
+        AscendC::GlobalTensor<int32_t> selfStatusTensor;
+        selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET));
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfStatusTensor[aivIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+        state = selfStatusTensor(aivIdx * UB_ALIGN);
+        if (state == 0) {
+            sumTarget = (float)1.0;
+            tokenFlag = TOKEN_FLAG_1;
+            selfStatusTensor(aivIdx * UB_ALIGN) = 0x3F800000;
+        } else {
+            sumTarget = 0.0;
+            tokenFlag = TOKEN_FLAG_2;
+            selfStatusTensor(aivIdx * UB_ALIGN) = 0;
+        }
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+            selfStatusTensor[aivIdx * UB_ALIGN]);
+        __asm__ __volatile__("");
+    }
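AivInitState reads a persistent word from the status window, flips it for the next launch, and derives everything launch-parity-dependent from the value it read: which half of the HCCL window to use (winDataSizeOffset) and which magic flag pattern marks fresh tokens, so no cleanup pass is needed between launches. A minimal host model of that ping-pong (illustrative; the real state lives in GM inside the HCCL window, and the sizes here are made up):

#include <cstdint>
#include <cstdio>

constexpr int32_t TOKEN_FLAG_1 = 0x55555555;
constexpr int32_t TOKEN_FLAG_2 = 0x33333333;

struct PersistentWord { int32_t dataState = 0; };  // survives across launches

void launch(PersistentWord &gm, uint64_t halfBytes) {
    int32_t dataState = gm.dataState;
    gm.dataState ^= 1;  // flip for the next launch, as AivInitState does
    int32_t tokenFlag = (dataState == 0) ? TOKEN_FLAG_1 : TOKEN_FLAG_2;
    uint64_t winOffset = (uint64_t)dataState * halfBytes;  // winDataSizeOffset analogue
    printf("state=%d flag=0x%08X window offset=%llu\n",
           dataState, (unsigned)tokenFlag, (unsigned long long)winOffset);
}

int main() {
    PersistentWord gm;
    launch(gm, 1u << 20);  // launch 1: half 0, flag pattern 1
    launch(gm, 1u << 20);  // launch 2: half 1, flag pattern 2
    return 0;
}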
+
+    CATLASS_DEVICE
+    void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums)
+    {
+        if (aivIdx == aiCoreGroupNum * subBlockNum - 1) {
+            // clean
+            AscendC::GlobalTensor<int32_t> groupTokenNumStateTensor;
+            groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET));
+            AscendC::LocalTensor<int32_t> tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte<int32_t>(0);
+            AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, GROUP_INFO_SIZE * localExpertNum);
+            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(0);
+            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(0);
+            AscendC::DataCopy(groupTokenNumStateTensor, tmpZeroLocalTensor, GROUP_INFO_SIZE * localExpertNum);
+        }
+
+        if (isRecvCore && recvCoreIdx == (recvCoreNum - 1)) {
+            // record the token count for each local expert
+            AscendC::GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;
+            expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList));
+            AscendC::GlobalTensor<int32_t> sendCountsGlobal;
+            sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount));
+            AscendC::GlobalTensor<int64_t> nonCumSumExpertTokenNumsTensor;
+            nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums);
+            uint32_t tmpTokenNum = 0;
+            for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) {
+                __asm__ __volatile__("");
+                AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+                    sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]);
+                __asm__ __volatile__("");
+
+                uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1);
+                expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum);
+                uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum;
+                nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum);
+                tmpTokenNum = tokenNum;
+
+                __asm__ __volatile__("");
+                AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+                    expertTokenNumsOutGMTensor_[localMoeIndex]);
+                __asm__ __volatile__("");
+                __asm__ __volatile__("");
+                AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+                    nonCumSumExpertTokenNumsTensor[localMoeIndex]);
+                __asm__ __volatile__("");
+            }
+        }
+    }
+
+    template <>
+    CATLASS_DEVICE void operator()<AscendC::AIV>(Params const &params)
+    {
+        AivInitParams(params);
+        AivInitState();
+        if (isSendCore) {
+            SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
+                         (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
+        }
+        if (isRecvCore) {
+            RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount);
+        }
+
+        auto gmSwigluOutput = reinterpret_cast<__gm__ float *>(
+            params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES * L1TileShape::N));
+        if (isCompCore) {
+            CompCoreFunc(params.ptrWorkspace, params.ptrScale, params.ptrPerTokenScale, gmSwigluOutput,
+                         params.problemShape.n(), params.problemShape.k(), params.layoutScale, params.layoutPerTokenScale,
+                         params.layoutOutput);
+        }
+
+        icache_preload(8);
+        AscendC::SyncAll();
+        AscendC::PipeBarrier<PIPE_ALL>();
+
+        UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums);
+        {
+            AscendC::GlobalTensor<int32_t> sendCountsGlobal;
+            sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.gmEpSendCount));
+            __asm__ __volatile__("");
+            AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(sendCountsGlobal);
+            __asm__ __volatile__("");
+            totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1);
+            AscendC::PipeBarrier<PIPE_ALL>();
+            uint32_t n = params.problemShape.n();
+            uint32_t nOut = params.problemShape.n() / 2;
+            uint32_t swigluRowOnce = 0;
+            CalQuantRow(nOut, swigluRowOnce);
+            auto swigluLayout = layout::RowMajor{totalTokenCount, n};
+            typename SwigluPost::Params swigluParams{
+                gmSwigluOutput, swigluLayout, params.ptrSwigluScale, params.layoutSwigluScale,
+                params.ptrOutput, params.layoutOutput, swigluRowOnce, nOut};
+
+            SwigluPost blockSwiglu(resource, swigluParams);
+            MatrixCoord swigluShape(totalTokenCount, nOut);
+            MatrixCoord swigluBlockShape((uint16_t)(subBlockNum * swigluRowOnce), nOut);
+            Epilogue::Tile::EpilogueHorizontalTileSwizzle swigluSwizzle(swigluShape, swigluBlockShape);
+            for (uint32_t loopIdx = aiCoreGroupIdx; loopIdx < swigluSwizzle.GetLoops(); loopIdx += aiCoreGroupNum) {
+                auto blockCoord = swigluSwizzle.GetTileCoord(loopIdx);
+                auto actualBlockShape = swigluSwizzle.GetActualTileShape(blockCoord);
+                blockSwiglu(swigluBlockShape, blockCoord, actualBlockShape);
+            }
+        }
+    }
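Both the matmul and epilogue loops in these kernels deal tiles round-robin: within an expert group, a core takes tiles startLoopIdx, startLoopIdx + coreNum, and so on; across groups, startCoreIdx advances by the previous group's tile count so the first tile of the next group lands on the next free core. A standalone model of that partition (group sizes are made up):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t coreNum = 4;
    const uint32_t groupLoops[3] = {6, 5, 7};  // tiles per expert group (illustrative)
    uint32_t startCoreIdx = 0;
    for (uint32_t g = 0; g < 3; ++g) {
        for (uint32_t core = 0; core < coreNum; ++core) {
            // same rotation arithmetic as the kernel's startLoopIdx
            uint32_t startLoopIdx =
                ((core < startCoreIdx) ? (core + coreNum) : core) - startCoreIdx;
            printf("group %u core %u:", g, core);
            for (uint32_t i = startLoopIdx; i < groupLoops[g]; i += coreNum)
                printf(" tile%u", i);
            printf("\n");
        }
        startCoreIdx = (startCoreIdx + groupLoops[g]) % coreNum;  // rotate the start
    }
    return 0;
}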
+
+private:
+    friend struct AicWaitFunc1;
+    friend struct AicSetFunc1;
+
+    struct AicWaitFunc1 {
+        CATLASS_DEVICE
+        AicWaitFunc1() = default;
+
+        CATLASS_DEVICE
+        void operator()() const
+        {
+            CheckSyncFlag(flagAddr, idx, target);
+        }
+
+        __gm__ uint8_t *flagAddr;
+        uint8_t idx;
+        uint32_t target;
+    };
+
+    struct AicSetFunc1 {
+        CATLASS_DEVICE
+        AicSetFunc1() = default;
+
+        CATLASS_DEVICE
+        void operator()() const
+        {
+            EncreaseSyncFlag(flagAddr, idx);
+        }
+
+        __gm__ uint8_t *flagAddr;
+        uint8_t idx;
+    };
+
+    AicWaitFunc1 aicWaitFunc1;
+    AicSetFunc1 aicSetFunc1;
+    Arch::Resource<ArchTag> resource;
+
+    AscendC::LocalTensor<int32_t> expertIdsTensor_;
+
+    // rank and expert info
+    uint32_t epRankSize{0};
+    uint32_t epRankId{0};
+    bool hasShareExpert{false};
+    bool isShareExpert{false};
+    uint32_t expertCntUp{0};
+    uint32_t localExpertNum{0};
+    uint32_t sharedExpertRankNum{0};
+    uint32_t moeExpertNumPerRank{0};
+    uint32_t moeExpertNum{0};
+
+    // token info
+    uint32_t hOutSize{0};
+    uint32_t scaleParamPad{0};
+    uint32_t hCommuSize{0};
+    uint32_t axisHCommu{0};
+    uint32_t axisHCommuBf16Fp16{0};
+    uint32_t axisBS{0};
+    uint32_t activeMaskBsCnt{0};
+    uint32_t axisK{0};
+    uint32_t totalTokenCount{0};
+    uint32_t expertIdsCnt{0};
+    uint32_t tokenLength{0};
+
+    // state info
+    int32_t tokenFlag{0};   // token flag
+    int32_t vToCFlag{0};    // cv flag, decided by cvDataState
+    int32_t dataState{0};   // data space state
+    int32_t cvDataState{0}; // cv flag state
+    int32_t state{0};       // count flag state
+    float sumTarget{0.0};
+
+    // memory info
+    __gm__ HcclOpResParam *winContext_;
+    GM_ADDR statusDataSpaceGm;
+    uint32_t stateOffset{0};
+    uint64_t expertPerSizeOnWin{0};
+    uint64_t winDataSizeOffset{0};
+
+    int64_t ubOffset;
+
+    // core info
+    bool isSendCore{false};
+    bool isRecvCore{false};
+    bool isCompCore{false}; // computes deq_swiglu
+    uint32_t aiCoreGroupNum{0};
+    uint32_t aiCoreGroupIdx{0};
+    uint32_t subBlockNum{0};
+    uint32_t aicNum{0};
+    uint32_t aivNum{0};
+    uint32_t sendCoreNum{0};
+    uint32_t recvCoreNum{0};
+    uint32_t compCoreNum{0};
+    uint32_t aivIdx{0};
+    uint32_t aicIdx{0};
+    uint32_t sendCoreIdx{0};
+    uint32_t recvCoreIdx{0};
+    uint32_t compCoreIdx{0};
+    uint32_t aivStateGlobalCoreIdx{0};
+    uint32_t aicStateGlobalCoreIdx{0};
+    uint32_t sendToMoeAivNum{0};
+    uint32_t sendToShareAivNum{0};
+};
+
+} // namespace Catlass::Gemm::Kernel
+
+namespace Catlass::Gemm::Kernel {
+
+template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_,
+          uint32_t WORKSPACE_STAGES_, class ElementGroupList_>
+class GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch
+{
+public:
+    using BlockMmad = BlockMmad_;
+    using ArchTag = typename BlockMmad::ArchTag;
+    using L1TileShape = typename BlockMmad::L1TileShape;
+    using ElementA = typename BlockMmad::ElementA;
+    using LayoutA = typename BlockMmad::LayoutA;
+    using ElementB = typename BlockMmad::ElementB;
+    using LayoutB = typename BlockMmad::LayoutB;
+    using ElementC = typename BlockMmad::ElementC;
+    using LayoutC = typename BlockMmad::LayoutC;
+    using ElementAccumulator = typename BlockMmad::ElementAccumulator;
+    using BlockEpilogue = BlockEpilogue_;
+    using ElementScale = typename BlockEpilogue::ElementRawScale;
+    using LayoutScale = typename BlockEpilogue::LayoutScale;
+    using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
+    using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
+    using ElementD = typename BlockEpilogue::ElementD;
+    using LayoutD = typename BlockEpilogue::LayoutD;
+    using EpilogueParams = typename BlockEpilogue::Params;
+
+    using XType = ExpandXType;
+    using ElementSwigluScale = typename SwigluPost::ElementSwigluScale;
+    using LayoutSwigluScale = typename SwigluPost::LayoutSwigluScale;
+    using ElementOutput = typename SwigluPost::ElementOutput;
+    using LayoutOutput = typename SwigluPost::LayoutOutput;
+
+    using BlockScheduler = BlockScheduler_;
+    static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
+    using ElementGroupList = ElementGroupList_;
+
+    /// Parameters structure
+    struct Params {
+        // Data members
+        GemmCoord problemShape;
+        uint32_t problemCount;
+        __gm__ ElementGroupList_ *ptrGroupList;
+        __gm__ ElementA *ptrA;
+        LayoutA layoutA;
+        __gm__ ElementB *ptrB;
+        LayoutB layoutB;
+        __gm__ ElementScale *ptrScale;
+        LayoutScale layoutScale;
+        __gm__ ElementPerTokenScale *ptrPerTokenScale;
+        LayoutPerTokenScale layoutPerTokenScale;
+        __gm__ ElementOutput *ptrOutput;
+        LayoutOutput layoutOutput;
+        __gm__ ElementSwigluScale *ptrSwigluScale;
+        LayoutSwigluScale layoutSwigluScale;
+        GM_ADDR ptrWorkspace;
+
+        // Methods
+        CATLASS_DEVICE
+        Params() {}
+
+        CATLASS_DEVICE
+        Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_,
+               LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_,
+               LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
+               LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
+               GM_ADDR ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, GM_ADDR ptrWorkspace_)
+            : problemShape(problemShape_),
+              problemCount(problemCount_),
+              ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)),
+              ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)),
+              layoutA(layoutA_),
+              ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)),
+              layoutB(layoutB_),
+              ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)),
+              layoutScale(layoutScale_),
+              ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)),
+              layoutPerTokenScale(layoutPerTokenScale_),
+              ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)),
+              layoutOutput(layoutOutput_),
+              ptrSwigluScale(reinterpret_cast<__gm__ ElementSwigluScale *>(ptrSwigluScale_)),
+              layoutSwigluScale(layoutSwigluScale_),
+              ptrWorkspace(ptrWorkspace_)
+        {}
+    };
+
+    // Methods
+    CATLASS_DEVICE
+    GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch()
+    {
+        Arch::FlagID flagId = 0;
+        for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) {
+            flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++);
+            flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++);
+            aicWaitFuncList[stageId] = {this, stageId};
+            aicSetFuncList[stageId] = {this, stageId};
+        }
+    }
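The constructor hands out cross-core flag IDs in pairs, 2*stage for "AIC finished storing" and 2*stage+1 for "AIV finished consuming"; together they gate each workspace slot like a bounded buffer. A host analogue with atomics (illustrative only; the real synchronization is the Arch::CrossCoreFlag mechanism, not atomics):

#include <atomic>
#include <cstdint>
#include <thread>

constexpr uint32_t kStages = 4;                              // WORKSPACE_STAGES
std::atomic<uint32_t> stored[kStages] = {0, 0, 0, 0};        // "AicFinishStore" per slot
std::atomic<uint32_t> consumed[kStages] = {0, 0, 0, 0};      // "AivFinishCompute" per slot

void aic(uint32_t loops) {
    for (uint32_t i = 0; i < loops; ++i) {
        uint32_t s = i % kStages;
        if (i >= kStages)                                    // reusing slot s: wait for the consumer
            while (consumed[s].load() < i / kStages) {}
        stored[s].fetch_add(1);                              // announce the fixpipe result
    }
}

void aiv(uint32_t loops) {
    for (uint32_t i = 0; i < loops; ++i) {
        uint32_t s = i % kStages;
        while (stored[s].load() < i / kStages + 1) {}        // wait for the AIC store
        consumed[s].fetch_add(1);                            // free the slot for reuse
    }
}

int main() {
    std::thread a(aic, 12), v(aiv, 12);
    a.join(); v.join();
    return 0;
}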
+
+    template <int32_t CoreType = g_coreType>
+    CATLASS_DEVICE void operator()(Params const &params);
+
+    template <>
+    CATLASS_DEVICE void operator()<AscendC::AIC>(Params const &params)
+    {
+        BlockScheduler blockScheduler;
+        BlockMmad blockMmad(resource);
+
+        // Represent the full gm
+        AscendC::GlobalTensor<ElementA> gmA;
+        gmA.SetGlobalBuffer(params.ptrA);
+        AscendC::GlobalTensor<ElementB> gmB;
+        gmB.SetGlobalBuffer(params.ptrB);
+        AscendC::GlobalTensor<ElementGroupList> groupList;
+        groupList.SetGlobalBuffer(params.ptrGroupList);
+
+        uint32_t coreIdx = AscendC::GetBlockIdx();
+        uint32_t coreNum = AscendC::GetBlockNum();
+        int64_t gmGroupOffsetA = 0;
+        int64_t gmGroupOffsetB = 0;
+
+        AscendC::GlobalTensor<ElementC> gmC;
+        gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
+        auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
+
+        uint32_t stageId = 0;
+        uint32_t stageUsed = 0;
+        uint32_t startCoreIdx = 0;
+        for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
+            uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
+                                                : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
+            GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
+
+            LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK());
+            LayoutB layoutB = params.layoutB;
+
+            blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N));
+            uint32_t coreLoops = blockScheduler.GetCoreLoops();
+
+            // Determine the starting loopIdx of the current core under the current groupIdx
+            uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
+            // Loop through the matmuls of this group
+            for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
+                // Compute block location
+                GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx);
+                GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord);
+
+                Callback callbackBeforeFixpipe{};
+                if (stageUsed == WORKSPACE_STAGES) {
+                    callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]);
+                } else {
+                    ++stageUsed;
+                }
+                Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]);
+
+                // Compute initial location in logical coordinates
+                MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K};
+                MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N};
+                MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
+                int64_t gmOffsetA = layoutA.GetOffset(offsetA);
+                int64_t gmOffsetB = layoutB.GetOffset(offsetB);
+                int64_t gmOffsetC = layoutC.GetOffset(offsetC);
+
+                // Compute block-scoped matrix multiply-add
+                if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
+                    blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
+                              gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe);
+                } else {
+                    callbackBeforeFixpipe();
+                    blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB,
+                              gmC[gmOffsetC], layoutC, actualBlockShape);
+                    callbackAfterFixpipe();
+                }
+
+                stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
+            }
+
+            gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
+            gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
+
+            startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
+        }
+
+        if constexpr (BlockMmad::DispatchPolicy::ASYNC) {
+            blockMmad.SynchronizeBlock();
+        }
+
+        while (stageUsed > 0) {
+            uint32_t aivComputeStageId =
+                (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed);
+            Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]);
+            --stageUsed;
+        }
+    }
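The drain loop that ends the AIC path waits on in-flight stages oldest-first: with stageId pointing at the next slot to be written and stageUsed outstanding slots, the oldest is (stageId - stageUsed) mod WORKSPACE_STAGES. A quick standalone check of that arithmetic (values chosen for illustration):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t stages = 4;
    uint32_t stageId = 1, stageUsed = 3;  // slots 2, 3, 0 were written; slot 1 is next
    while (stageUsed > 0) {
        uint32_t oldest = (stageId >= stageUsed)
                              ? (stageId - stageUsed)
                              : (stageId + stages - stageUsed);
        printf("wait on slot %u\n", oldest);  // prints 2, then 3, then 0
        --stageUsed;
    }
    return 0;
}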
+
+    template <>
+    CATLASS_DEVICE void operator()<AscendC::AIV>(Params const &params)
+    {
+        uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum();
+        uint32_t coreNum = AscendC::GetBlockNum();
+        int64_t gmGroupOffsetScale = 0;
+        int64_t gmGroupOffsetPerTokenScale = 0;
+        int64_t gmGroupOffsetD = 0;
+
+        AscendC::GlobalTensor<ElementGroupList> groupList;
+        groupList.SetGlobalBuffer(params.ptrGroupList);
+
+        AscendC::GlobalTensor<ElementC> gmC;
+        gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace));
+        auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N};
+
+        auto ptrD = reinterpret_cast<__gm__ float *>(
+            params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N));
+
+        uint32_t mActual = groupList.GetValue(params.problemCount - 1);
+        uint32_t n = params.problemShape.n();
+        uint32_t nOut = params.problemShape.n() / 2;
+
+        {
+            BlockScheduler blockScheduler;
+            BlockEpilogue blockEpilogue(resource);
+
+            uint32_t stageId = 0;
+            uint32_t startCoreIdx = 0;
+            for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
+                uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
+                                                    : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
+                GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
+
+                LayoutScale layoutScale = params.layoutScale;
+                LayoutPerTokenScale layoutPerTokenScale =
+                    params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
+                LayoutD layoutD = layout::RowMajor{currentM, n};
+
+                EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
+                                              layoutScale,
+                                              params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
+                                              layoutPerTokenScale,
+                                              ptrD + gmGroupOffsetD,
+                                              layoutD};
+
+                blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
+                blockEpilogue.UpdateParams(epilogueParams);
+                uint32_t coreLoops = blockScheduler.GetCoreLoops();
+
+                GemmCoord blockShapeMNK = L1TileShape::ToCoord();
+                uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
+                for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) {
+                    GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx);
+                    GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK);
+
+                    MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0};
+                    int64_t gmOffsetC = layoutC.GetOffset(offsetC);
+                    auto gmBlockC = gmC[gmOffsetC];
+                    auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN());
+
+                    Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]);
+                    blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC);
+                    Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]);
+
+                    stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
+                }
+
+                gmGroupOffsetScale += inGroupProblemShape.n();
+                gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
+                gmGroupOffsetD += currentM * n;
+
+                startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
+            }
+        }
+
+        Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
+
+        {
+            uint32_t swigluRowOnce = 0;
+            CalQuantRow(nOut, swigluRowOnce);
+            auto swigluLayout = layout::RowMajor{mActual, n};
+            typename SwigluPost::Params swigluParams{ptrD,
+                                                     swigluLayout,
+                                                     params.ptrSwigluScale,
+                                                     params.layoutSwigluScale,
+                                                     params.ptrOutput,
+                                                     params.layoutOutput,
+                                                     swigluRowOnce,
+                                                     nOut};
+
+            SwigluPost blockSwiglu(resource, swigluParams);
+            MatrixCoord swigluShape(mActual, nOut);
+            MatrixCoord swigluBlockShape((uint16_t)(AscendC::GetSubBlockNum() * swigluRowOnce), nOut);
+            Epilogue::Tile::EpilogueHorizontalTileSwizzle swigluSwizzle(swigluShape, swigluBlockShape);
+            for (uint32_t loopIdx = coreIdx; loopIdx < swigluSwizzle.GetLoops(); loopIdx += coreNum) {
+                auto blockCoord = swigluSwizzle.GetTileCoord(loopIdx);
+                auto actualBlockShape = swigluSwizzle.GetActualTileShape(blockCoord);
+
+                blockSwiglu(swigluBlockShape, blockCoord, actualBlockShape);
+            }
+        }
+    }
+
+private:
+    friend struct AicWaitFunc;
+    friend struct AicSetFunc;
+
+    struct AicWaitFunc {
+        using MatmulKernel = GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch<
+            TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
+
+        CATLASS_DEVICE
+        AicWaitFunc() = default;
+
+        CATLASS_DEVICE
+        void operator()() const
+        {
+            Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]);
+        }
+
+        MatmulKernel *ptr{nullptr};
+        uint32_t stageId;
+    };
+
+    struct AicSetFunc {
+        using MatmulKernel = GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch<
+            TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
+
+        CATLASS_DEVICE
+        AicSetFunc() = default;
+
+        CATLASS_DEVICE
+        void operator()() const
+        {
+            Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]);
+        }
+
+        MatmulKernel *ptr{nullptr};
+        uint32_t stageId;
+    };
+
+    Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES];
+    Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES];
+
+    AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES];
+    AicSetFunc aicSetFuncList[WORKSPACE_STAGES];
+    Arch::Resource<ArchTag> resource;
+};
+
+} // namespace Catlass::Gemm::Kernel
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h
index c8fa2e1f..c80678e4 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h
@@ -9,7 +9,9 @@
  */
 #ifndef CAM_MOE_DISTRIBUTE_COMBINE_H
 #define CAM_MOE_DISTRIBUTE_COMBINE_H
+#ifndef OPT_RANK_OFFSET
 #define OPT_RANK_OFFSET 512
+#endif
 
 #include "kernel_operator.h"
 #include "kernel_tiling/kernel_tiling.h"
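Wrapping the definition in #ifndef (here and in the dispatch header below) turns OPT_RANK_OFFSET from a hard-coded constant into an overridable default: a build could now pre-define its own value on the compile line (for example, a hypothetical -DOPT_RANK_OFFSET=1024) without patching the header. The pattern in isolation:

/* Overridable-default macro: the in-header value applies only when the
 * build has not already defined OPT_RANK_OFFSET. */
#ifndef OPT_RANK_OFFSET
#define OPT_RANK_OFFSET 512
#endif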
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h
index f73e2c60..1cc430bd 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h
@@ -10,7 +10,9 @@
 #ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H
 #define CAM_MOE_DISTRIBUTE_DISPATCH_H
+#ifndef OPT_RANK_OFFSET
 #define OPT_RANK_OFFSET 512
+#endif
 
 #include "kernel_operator.h"
 #include "kernel_tiling/kernel_tiling.h"
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h
index b9ac8932..cd4dd6b9 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h
@@ -12,10 +12,107 @@
 
 #include "../common/moe_distribute_base.h"
 
-#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
-#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
+#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename WType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
+#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, WType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
 
 #define TemplateDispatchTypeClass \
     typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
     bool IsNeedAllgater, uint32_t EXEC_FLAG
 #define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
+
+constexpr uint32_t STATE_OFFSET = 512;
+constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
+constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
+constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024;
+constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024;
+constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
+constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024;
+constexpr uint32_t UB_ALIGN = 32;
+constexpr uint32_t TOKEN_EXTRA_SPACE = 512;
+constexpr uint32_t INT32_COUNT_PER_BLOCK = 8;
+constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512;
+constexpr int64_t LOOP_TMP_SIZE = 4096;
+constexpr int32_t SUB_AIV_NUM = 2;
+constexpr int32_t ODD_EVEN_BASE = 2;
+constexpr int32_t BUFFER_NUM = 2;
+constexpr int32_t GATHER_SECOND_NUM = 2;
+constexpr uint32_t MAX_QUANT_ROW_ONCE = 8;
+constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant
+#ifndef OPT_RANK_OFFSET
+#define OPT_RANK_OFFSET 512
+#endif
+
+#define CEIL_UP(x) (((x) + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN)
+#define CEIL(x, y) (((x) + ((y) - 1)) / (y))
+#define UB_BLOCK_SIZE (32)
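A quick worked check of the alignment helpers just defined (with UB_ALIGN = 32, CEIL_UP rounds a byte count up to the next 32-byte UB block and CEIL is plain ceiling division; the macro arguments are fully parenthesized so expressions like CEIL_UP(a + b) expand correctly):

#include <cstdint>
#include <cstdio>

constexpr uint32_t UB_ALIGN = 32;
#define CEIL_UP(x) (((x) + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN)
#define CEIL(x, y) (((x) + ((y) - 1)) / (y))

int main() {
    printf("%u\n", CEIL_UP(100u));   // 128: 100 bytes occupy four 32-byte blocks
    printf("%u\n", CEIL(100u, 32u)); // 4
    return 0;
}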
+#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \
+    (((epRankId == rankId) \
+          ? ((GM_ADDR)(winContext_->localWindowsExp)) \
+          : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \
+     dataState * WIN_STATE_OFFSET)
+#define GET_WIND_ADDR_BY_RANK_ID(rankId) \
+    (((epRankId == rankId) \
+          ? ((GM_ADDR)(winContext_->localWindowsIn)) \
+          : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \
+     winDataSizeOffset + rankId * OPT_RANK_OFFSET)
+#define TOKEN_FLAG_1 (0x55555555)
+#define TOKEN_FLAG_2 (0x33333333)
+#define V_TO_C_FLAG_1 (0x03030303)
+#define V_TO_C_FLAG_2 (0x05050505)
+#define CV_FLAG_INDEX 0
+#define GROUP_ID_INDEX 1
+#define PRE_COUNT_INDEX 2
+#define SELF_COUNT_INDEX 3
+#define TOTAL_COUNT_INDEX 4
+#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX
+#define GROUP_INFO_SIZE 32
+
+__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx)
+{
+    // flag++, like a set-flag
+    AscendC::PipeBarrier<PIPE_ALL>();
+    AscendC::GlobalTensor<uint8_t> global;
+    global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+        global);
+    __asm__ __volatile__("");
+    uint8_t value = global.GetValue(0);
+    global.SetValue(0, value + 1);
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
+        global);
+    __asm__ __volatile__("");
+    AscendC::PipeBarrier<PIPE_ALL>();
+}
+
+__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target)
+{
+    // check flag, like a wait-flag
+    AscendC::PipeBarrier<PIPE_ALL>();
+    AscendC::GlobalTensor<uint8_t> global;
+    global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE);
+    while (true) {
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(global);
+        __asm__ __volatile__("");
+        uint8_t value = global.GetValue(0);
+        if (value >= target) {
+            __asm__ __volatile__("");
+            AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(global);
+            __asm__ __volatile__("");
+            break;
+        }
+    }
+    AscendC::PipeBarrier<PIPE_ALL>();
+}
+
+__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
+{
+    row = QUANT_SPACE_FACTOR / column;
+    row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
+}
+
 
 #endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h
new file mode 100644
index 00000000..28d54759
--- /dev/null
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h
@@ -0,0 +1,457 @@
+/*
+ * Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ +#ifndef DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H +#define DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H + +#include "lib/matmul_intf.h" +#include + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/epilogue/tile/tile_broadcast_mul.hpp" +#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h" +#include "catlass/gemm/gemm_type.hpp" +#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h" +#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h" + +#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h" + +#include "dispatch_gmm_combine_decode_tiling.h" +#include "dispatch_gmm_combine_decode_base.h" + +using namespace Catlass; + +namespace DispatchGmmCombineDecodeBf16Fp16Impl { + +using MmadAtlasA2Custom = + Gemm::MmadAtlasA2PreloadAsyncWithCallback; + +using Gmm1L1TileShape = GemmShape; +using Gmm1L0TileShape = GemmShape; +using Gmm1EpilogueTileShape = MatrixShape; +using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; + +using Gmm2L1TileShape = GemmShape; +using Gmm2L0TileShape = GemmShape; +using Gmm2EpilogueTileShape = MatrixShape; +using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; +using Gmm2DispatchPolicy = + Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA; + +template +CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, + GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, + GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums, + uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, + uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, + uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Swiglu; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using 
EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + + using GemmKernel = typename std::conditional< + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0, + Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0) { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace, + gmX, + debugGm, + gmexpertIds, + gmExpandIdx, + gmEpSendCount, + xActiveMask, + gmResvered, + gmExpertTokenNums, + epRankSize, + epRankId, + moeExpertNum, + moeExpertNumPerRank, + sharedExpertNum, + sharedExpertRankNum, + quantMode, + globalBs, + bs, + topK, + tokenLen}; + // call a kernel + GemmKernel gemm; + gemm(params); + } else { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace}; + // call a kernel + GemmKernel gemm; + gemm(params); + } +} + +template +CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmWorkspace, void *combiner) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Combine; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = 
Epilogue::Tile::EpilogueHorizontalTileSwizzle;
+
+    using BlockEpilogue = Epilogue::Block::BlockEpilogue;
+
+    using BlockScheduler = BlockScheduler_;
+
+    // kernel level
+    using ElementGroupList = int64_t;
+    using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMMultiStageWorkspace<
+        TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
+
+    typename GemmKernel::Params params{
+        problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale,
+        layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};
+
+    // call a kernel
+    GemmKernel gemm;
+    gemm(params);
+}
+
+template <TemplateMC2TypeClass>
+class DispatchGmmCombineDecodeBf16Fp16
+{
+public:
+    __aicore__ inline DispatchGmmCombineDecodeBf16Fp16(){};
+    __aicore__ inline void Init(
+        // input
+        GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
+        GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
+        // output
+        GM_ADDR output, GM_ADDR expertTokenNums,
+        // system
+        GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
+    __aicore__ inline void Process();
+
+private:
+    GM_ADDR gmX_;
+    GM_ADDR gmexpertIds_;
+    GM_ADDR gmPermuteWeight1_;
+    GM_ADDR gmPermuteScale1_;
+    GM_ADDR gmWeight2_;
+    GM_ADDR gmScale2_;
+    GM_ADDR gmOutput_;
+    GM_ADDR gmExpertTokenNums_;
+    GM_ADDR workspaceGM_;
+    GM_ADDR gmSmoothScales_;
+    GM_ADDR gmexpertScales_;
+    GM_ADDR xActiveMask_;
+
+    uint32_t maxTokenNum_{0};
+    uint32_t gmm1OutputDim_{0};
+    uint32_t tokenHiddenSize_{0};
+    uint32_t groupCount_{0};
+    uint32_t gmm2OutputDim_{0};
+    uint32_t gmm2InputDim_{0};
+    uint32_t globalRankId_{0};
+    uint32_t winSizePerRank_{0};
+    uint32_t blockDim_{0};
+    uint32_t epRankSize_{0};
+    uint32_t epRankId_{0};
+    uint32_t moeExpertNum_{0};
+    uint32_t moeExpertNumPerRank_{0};
+    uint32_t sharedExpertNum_{0};
+    uint32_t sharedExpertRankNum_{0};
+    uint32_t quantMode_{0};
+    uint32_t globalBs_{0};
+    uint32_t bs_{0};
+    uint32_t maxBs_{0};
+    uint32_t topK_{0};
+
+    AscendC::TPipe *tpipe_{nullptr};
+    __gm__ HcclOpResParam *winContext_{nullptr};
+    const DispatchGmmCombineDecodeTilingData *tilingData_;
+};
+
+template <TemplateMC2TypeClass>
+__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16<TemplateMC2TypeFunc>::Init(
+    // input
+    GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
+    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
+    GM_ADDR x_active_mask,
+    // output
+    GM_ADDR output, GM_ADDR expertTokenNums,
+    // system
+    GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
+{
+    tpipe_ = pipe;
+    blockDim_ = AscendC::GetBlockNum();
+    winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext();
+
+    gmSmoothScales_ = expert_smooth_scales; // not used now
+    gmX_ = x; // input token
+    gmexpertIds_ = expert_ids;
+    gmPermuteWeight1_ = gmm1_permuted_weight;
+    gmPermuteScale1_ = nullptr;
+    gmWeight2_ = gmm2_weight;
+    gmScale2_ = nullptr;
+    gmOutput_ = output;
+    gmExpertTokenNums_ = expertTokenNums;
+    workspaceGM_ = workspaceGM;
+    gmexpertScales_ = expert_scales;
+    xActiveMask_ = x_active_mask;
+    tilingData_ = tilingData;
+    epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
+    epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
+    moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
+    moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
+    sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum;
+    sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
+    quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode;
+    globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
+    bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
+    topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
+    maxBs_ = globalBs_ / epRankSize_;
+
+    bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
+    if (isShareExpert) {
+        maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
+    } else {
+        maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
+    }
+
+    gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
+    tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
+    groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
+    gmm2OutputDim_ = tokenHiddenSize_;
+    gmm2InputDim_ = gmm1OutputDim_ / 2;
+}
+
+template <uint32_t EXEC_FLAG>
+__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) {
+    if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) {
+        MatrixCoord mc{k, n};
+        return layout::RowMajor::template MakeLayoutInUb(mc);
+    } else {
+        return layout::zN::template MakeLayout(k, n);
+    }
+}
+
+template <TemplateMC2TypeClass>
+__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16<TemplateMC2TypeFunc>::Process()
+{
+    using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type;
+    GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_};
+    GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_};
+
+    layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_};
+    auto layoutWeight1 = CreateWeightLayout<EXEC_FLAG>(tokenHiddenSize_, gmm1OutputDim_);
+    layout::VectorLayout layoutW1Scale{gmm1OutputDim_};
+    layout::VectorLayout layoutX1Scale{maxTokenNum_};
+    layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_};
+    auto layoutWeight2 = CreateWeightLayout<EXEC_FLAG>(gmm2InputDim_, gmm2OutputDim_);
+    layout::VectorLayout layoutW2Scale{gmm2OutputDim_};
+    layout::VectorLayout layoutX2Scale{maxTokenNum_};
+    layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_};
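Process() next carves a single workspace arena with a running offset and rounded-up block sizes; buffers that are never live at the same time (the dispatch output X1 and the swiglu output X2, or the swiglu scratch and the GMM2 output) alias one slot sized to the larger of the pair. A minimal host model of that bump allocation (illustrative sizes; the RoundUp granularity of 512 bytes is an assumption):

#include <cstddef>
#include <cstdint>
#include <cstdio>

static size_t RoundUp(size_t n, size_t align = 512) {  // granularity assumed
    return (n + align - 1) / align * align;
}

int main() {
    size_t offset = 0;
    size_t offX1 = offset;                        // gmX1 and gmX2 share this slot
    offset += RoundUp(1000003);                   // max(x1TokenSize, x2TokenSize)
    size_t offCVSwap = offset;                    // gmWorkspace / gmCVSwap staging
    offset += RoundUp(4096 * sizeof(float));
    size_t offSwiglu = offset;                    // gmSwigluOut and gmGmm2DepOut share
    offset += RoundUp(2000000);                   // max(swigluOutSize, gmm2OutSize)
    printf("X1/X2 at +%zu, CV swap at +%zu, swiglu/GMM2 at +%zu, total %zu\n",
           offX1, offCVSwap, offSwiglu, offset);
    return 0;
}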
+
+    size_t workspaceOffset = 0;
+    constexpr int32_t resveredWorkSpaceSize = 256 * 1024;
+    int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
+    int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(ExpandXType);
+    int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize;
+    GM_ADDR gmX1 = workspaceGM_ + workspaceOffset;
+    GM_ADDR gmX2 = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(maxTokenSize);
+    GM_ADDR gmX1Scale = nullptr;
+    GM_ADDR gmX2Scale = nullptr;
+    GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset;
+    GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(static_cast<int64_t>(blockDim_) * (FP16_BF16_L1M * FP16_BF16_L1N) *
+                               WORKSPACE_STAGES * sizeof(float));
+    int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float);
+    int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType);
+    int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize;
+    GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset;
+    GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(maxSwigluGmm2Size);
+    GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(static_cast<int64_t>(groupCount_) * sizeof(int64_t));
+    GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(static_cast<int64_t>(bs_) * topK_ * sizeof(int32_t));
+    GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(static_cast<int64_t>(epRankSize_) * groupCount_ * sizeof(int32_t));
+    GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
+    workspaceOffset += RoundUp(resveredWorkSpaceSize);
+
+    if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
+        if constexpr (g_coreType == AscendC::AIV) {
+            AscendC::TPipe tpipe;
+            MoeDistributeDispatchImpl::CamMoeDistributeDispatch
+                dispatcher;
+            dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
+                            gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_);
+            dispatcher.Process();
+            tpipe.Destroy();
+            icache_preload(8);
+        }
+
+        AscendC::PipeBarrier<PIPE_ALL>();
+        Arch::CrossCoreFlag gmm1AivFinished{0};
+        if constexpr (g_coreType == AscendC::AIV) {
+            Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
+            Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
+        } else {
+            Arch::CrossCoreWaitFlag(gmm1AivFinished);
+        }
+    }
+    GmmDeqSwigluQuant(
+        gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
+        gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
+        layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
+        gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
+        sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
+    AscendC::PipeBarrier<PIPE_ALL>();
+    Arch::CrossCoreFlag gmm1AivFinished{0};
+    if constexpr (g_coreType == AscendC::AIV) {
+        Arch::CrossCoreBarrier<0x0, PIPE_MTE3>();
+        Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished);
+    } else {
+        Arch::CrossCoreWaitFlag(gmm1AivFinished);
+    }
+
+    MoeDistributeCombineImpl::CamMoeDistributeCombine combiner;
+    if (g_coreType == AscendC::AIV) {
+        combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
+                      workspaceGM_, nullptr, tilingData_);
+    }
+    GmmDeq(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2,
+           gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut,
+           layoutOutput, gmWorkspace, &combiner);
+}
+} // namespace DispatchGmmCombineDecodeBf16Fp16Impl
+#endif // DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H
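Throughout the new header, the B-matrix layout is selected at compile time from a flag bit: EXEC_FLAG_ND_FORMAT picks layout::RowMajor (plain ND) over layout::zN (FRACTAL_NZ). A standalone model of the std::conditional pattern (the layout structs here are stand-ins, not the Catlass types):

#include <cstdint>
#include <type_traits>

struct RowMajor {};  // stands in for layout::RowMajor (ND)
struct zN {};        // stands in for layout::zN (FRACTAL_NZ)

constexpr uint32_t EXEC_FLAG_ND_FORMAT = 1u << 3;  // matches the tiling header below

template <uint32_t EXEC_FLAG>
using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0,
                                          RowMajor, zN>::type;

static_assert(std::is_same<LayoutB<EXEC_FLAG_ND_FORMAT>, RowMajor>::value, "ND path");
static_assert(std::is_same<LayoutB<0>, zN>::value, "NZ path");

int main() { return 0; }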
diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h
index 3874ffa2..b006d9d7 100644
--- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h
+++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h
@@ -31,6 +31,8 @@ struct DispatchGmmCombineDecodeInfo {
     uint64_t totalWinSize;
     uint64_t gmm1HLen;
     bool isTensorList;
+    bool isBf16Fp16W;
+    bool isNDFormat;
 };
 
 struct DispatchGmmCombineDecodeTilingData {
@@ -48,6 +50,8 @@
 constexpr uint32_t CUSTOM_L0C_STAGES = 1;
 constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true;
 constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true;
 
+constexpr uint32_t FP16_BF16_L1M = 128;
+constexpr uint32_t FP16_BF16_L1N = 128;
 constexpr uint32_t GMM1_L1M = 256;
 constexpr uint32_t GMM1_L1N = 128;
 constexpr uint32_t GMM1_L1K = 512;
@@ -56,6 +60,8 @@
 constexpr uint32_t GMM1_EPIM = 64;
 constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3;
 constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0;
 
+constexpr uint32_t FP16_BF16_GMM2_L1M = 64;
+constexpr uint32_t FP16_BF16_GMM2_L1N = 128;
 constexpr uint32_t GMM2_L1A_STAGES = 4;
 constexpr uint32_t GMM2_L1B_STAGES = 2;
 constexpr uint32_t GMM2_L0A_STAGES = 4;
@@ -73,5 +79,6 @@
 constexpr uint32_t WORKSPACE_STAGES = 4;
 
 constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
 constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1);
 constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
+constexpr uint32_t EXEC_FLAG_ND_FORMAT = (1U << 3);
 
 #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H
diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py
index b928ad7e..278e58c0 100644
--- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py
+++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py
@@ -4,6 +4,7 @@ import sys
 from pathlib import Path
 
 import numpy as np
+import time
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -28,7 +29,9 @@ BASE_KWARGS = {
     "enable_dynamic_bs": False,
     "test_graph": False,
     "with_mc2_mask": False,
-    "dynamic_eplb": False
+    "dynamic_eplb": False,
+    "w8a8_dynamic": True,
+    "is_nz": True
 }
 
@@ -50,7 +53,7 @@ def permute_weight(w: torch.Tensor, tile_n):
 
 def output_to_file(rank_id):
-    return False
+    return rank_id > 0
 
 class DecodeMoeOps(torch.nn.Module):
@@ -68,8 +71,14 @@ class DecodeMoeOps(torch.nn.Module):
                  moe_expert_num,
                  global_rank_id,
                  shared_expert_rank_num=0,
-                 dynamic_eplb=False):
+                 dynamic_eplb=False,
+                 w8a8_dynamic=True,
+                 is_nz=True):
         super().__init__()
+        if w8a8_dynamic:
+            assert (gmm1_weight_scale is not None and gmm2_weight_scale is not None), "gmm1_weight_scale and gmm2_weight_scale must be provided for w8a8_dynamic"
+        else:
+            assert (gmm1_weight_scale is None and gmm2_weight_scale is None), "gmm1_weight_scale and gmm2_weight_scale must be None when w8a8_dynamic is False"
         self.ep_hcomm_info = ep_hcomm_info
         self.batch_size = batch_size
         self.token_hidden_size = token_hidden_size
@@ -84,38 +93,47 @@ class DecodeMoeOps(torch.nn.Module):
         self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
         self.ep_recv_count_size = self.local_expert_num * ep_world_size
         self.dynamic_eplb = dynamic_eplb
+        self.w8a8_dynamic = w8a8_dynamic
+        self.is_nz = is_nz
         self.gmm1_weight = torch.empty([
             self.local_expert_num, self.token_hidden_size,
             self.moe_intermediate_size * 2
         ])
-        self.gmm1_weight_scale = torch.empty(
-            [self.local_expert_num, self.moe_intermediate_size * 2])
         self.gmm2_weight = torch.empty([
             self.local_expert_num, self.moe_intermediate_size,
             self.token_hidden_size
         ])
-        self.gmm2_weight_scale = torch.empty(
-            [self.local_expert_num, self.token_hidden_size])
+        if self.w8a8_dynamic:
+            self.gmm1_weight_scale = torch.empty(
+                [self.local_expert_num, self.moe_intermediate_size * 2])
+            self.gmm2_weight_scale = torch.empty(
+                [self.local_expert_num, self.token_hidden_size])
+        else:
+            self.gmm1_weight_scale = None
+            self.gmm2_weight_scale = None
+            self.gmm1_weight_scale_fp32 = None
+            self.gmm2_weight_scale_fp32 = None
         self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale)
 
     def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                        gmm2_weight, gmm2_weight_scale):
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
+        if self.w8a8_dynamic:
+            gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                    torch_npu.Format.FRACTAL_NZ)
+            gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
+                                                    torch_npu.Format.FRACTAL_NZ)
         self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
-        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
-                                                    requires_grad=False)
         self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
-        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
-                                                    requires_grad=False)
-
-        self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
-            gmm1_weight_scale.float(), requires_grad=False)
-        self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
-            gmm2_weight_scale.float(), requires_grad=False)
+        if self.w8a8_dynamic:
+            self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
+                                                        requires_grad=False)
+            self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
+                                                        requires_grad=False)
+            self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
+                gmm1_weight_scale.float(), requires_grad=False)
+            self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
+                gmm2_weight_scale.float(), requires_grad=False)
 
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
                    x_active_mask):
@@ -142,12 +160,15 @@ class SmallOps(DecodeMoeOps):
                  moe_expert_num,
                  global_rank_id,
                  shared_expert_rank_num=0,
-                 dynamic_eplb=False):
+                 dynamic_eplb=False,
+                 w8a8_dynamic=True,
+                 is_nz=True):
         super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
                          gmm2_weight_scale, ep_hcomm_info, batch_size,
                          token_hidden_size, moe_intermediate_size,
                          ep_world_size, moe_expert_num, global_rank_id,
-                         shared_expert_rank_num, dynamic_eplb)
+                         shared_expert_rank_num, dynamic_eplb, w8a8_dynamic,
+                         is_nz)
         self.tp_hcomm_info = ""
 
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
@@ -167,7 +188,7 @@ class SmallOps(DecodeMoeOps):
             expert_shard_type=0,
             shared_expert_num=1,
             shared_expert_rank_num=self.shared_expert_rank_num,
-            quant_mode=2,
+            quant_mode=2 if self.w8a8_dynamic else 0,
             global_bs=self.batch_size * self.ep_world_size,
             expert_token_nums_type=1,  # 0 means prefix sums, 1 means per-expert counts
         )
@@ -181,22 +202,26 @@
             group_list_type=1,  # defaults to 0, which means prefix-sum form
             group_type=0,  # 0 means grouping along the m axis
             group_list=expert_token_nums,
-            output_dtype=torch.int32)[0]
-        y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=y1_int32,
-            weight_scale=self.gmm1_weight_scale.to(torch.float32),
-            activation_scale=dynamic_scales,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=expert_token_nums,
-            activate_left=True,
-            quant_mode=1,
-        )
+            output_dtype=torch.int32 if self.w8a8_dynamic else output_dtype)[0]
+        y1_scale = None
+        if self.w8a8_dynamic:
+            y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=y1_int32,
+                weight_scale=self.gmm1_weight_scale.to(torch.float32),
+                activation_scale=dynamic_scales,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=expert_token_nums,
+                activate_left=True,
+                quant_mode=1,
+            )
+        else:
+            y1 = torch_npu.npu_swiglu(y1_int32)
         y2 = torch_npu.npu_grouped_matmul(x=[y1],
                                           weight=[self.gmm2_weight],
-                                          scale=[self.gmm2_weight_scale],
-                                          per_token_scale=[y1_scale],
+                                          scale=[self.gmm2_weight_scale] if self.w8a8_dynamic else None,
+                                          per_token_scale=[y1_scale] if self.w8a8_dynamic else
None, split_item=2, group_list_type=1, group_type=0, @@ -240,15 +265,19 @@ class FusionOp(DecodeMoeOps): moe_expert_num, global_rank_id, shared_expert_rank_num=0, - dynamic_eplb=False): + dynamic_eplb=False, + w8a8_dynamic=True, + is_nz=True): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num, dynamic_eplb) + shared_expert_rank_num, dynamic_eplb, w8a8_dynamic, + is_nz) def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, x_active_mask): + smooth_scales = torch.zeros(128 * 1024 * 1024).npu() output = torch.ops._C_ascend.dispatch_gmm_combine_decode( x=x, expert_ids=expert_ids, @@ -271,29 +300,35 @@ class FusionOp(DecodeMoeOps): def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): - gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, - torch_npu.Format.FRACTAL_NZ) - gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, - torch_npu.Format.FRACTAL_NZ) + if self.is_nz: + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, + torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, + torch_npu.Format.FRACTAL_NZ) if self.dynamic_eplb: self.gmm1_weight = [ weight.clone() for weight in gmm1_weight.unbind(dim=0) ] - self.gmm1_weight_scale_fp32 = [ - weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) - ] self.gmm2_weight = [ weight.clone() for weight in gmm2_weight.unbind(dim=0) ] - self.gmm2_weight_scale_fp32 = [ - weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) - ] + if self.w8a8_dynamic: + self.gmm1_weight_scale_fp32 = [ + weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) + ] + self.gmm2_weight_scale_fp32 = [ + weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) + ] else: self.gmm1_weight = [gmm1_weight.clone()] - self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] self.gmm2_weight = [gmm2_weight.clone()] - self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + if self.w8a8_dynamic: + self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] + self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + else: + self.gmm1_weight_scale_fp32 = [torch.ones(1).npu().to(gmm1_weight.dtype)] + self.gmm2_weight_scale_fp32 = [torch.ones(1).npu().to(gmm2_weight.dtype)] def generate_datas(batch_size, @@ -306,7 +341,8 @@ def generate_datas(batch_size, top_k=8, test_bfloat16=True, enable_dynamic_bs=False, - with_mc2_mask=False): + with_mc2_mask=False, + w8a8_dynamic=True): is_shared_expert = global_rank_id < shared_expert_rank_num moe_expert_num_per_rank = moe_expert_num // (ep_world_size - shared_expert_rank_num) @@ -318,41 +354,59 @@ def generate_datas(batch_size, gmm1_output_dim = moe_intermediate_size * 2 gmm2_input_dim = moe_intermediate_size gmm2_output_dim = token_hidden_size - x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5 + x = torch.rand([actual_bs, token_hidden_size]) * 0.5 - 0.5 expert_ids = torch.arange( global_rank_id * batch_size * top_k, global_rank_id * batch_size * top_k + actual_bs * top_k).to( torch.int32).view(actual_bs, top_k) expert_ids = expert_ids % moe_expert_num - if is_shared_expert: - gmm1_weight = torch.ones([ - local_expert_num, gmm1_input_dim, gmm1_output_dim - ]).to(torch.int8) * 4 - gmm2_weight = torch.ones([ - local_expert_num, gmm2_input_dim, gmm2_output_dim - ]).to(torch.int8) * 4 - gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1 - gmm2_weight[:, :, 
+                self.gmm1_weight_scale_fp32 = [
+                    torch.ones(1).npu().to(gmm1_weight.dtype)
+                ]
+                self.gmm2_weight_scale_fp32 = [
+                    torch.ones(1).npu().to(gmm2_weight.dtype)
+                ]


 def generate_datas(batch_size,
@@ -306,7 +341,8 @@ def generate_datas(batch_size,
                    top_k=8,
                    test_bfloat16=True,
                    enable_dynamic_bs=False,
-                   with_mc2_mask=False):
+                   with_mc2_mask=False,
+                   w8a8_dynamic=True):
     is_shared_expert = global_rank_id < shared_expert_rank_num
     moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                  shared_expert_rank_num)
@@ -318,41 +354,59 @@ def generate_datas(batch_size,
     gmm1_output_dim = moe_intermediate_size * 2
     gmm2_input_dim = moe_intermediate_size
     gmm2_output_dim = token_hidden_size
-    x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
+    # tokens now drawn from [-0.5, 0); presumably the smaller magnitude keeps
+    # the non-quantized bf16/fp16 path inside the assert_close tolerances
+    x = torch.rand([actual_bs, token_hidden_size]) * 0.5 - 0.5
     expert_ids = torch.arange(
         global_rank_id * batch_size * top_k,
         global_rank_id * batch_size * top_k + actual_bs * top_k).to(
             torch.int32).view(actual_bs, top_k)
     expert_ids = expert_ids % moe_expert_num
-    if is_shared_expert:
-        gmm1_weight = torch.ones([
-            local_expert_num, gmm1_input_dim, gmm1_output_dim
-        ]).to(torch.int8) * 4
-        gmm2_weight = torch.ones([
-            local_expert_num, gmm2_input_dim, gmm2_output_dim
-        ]).to(torch.int8) * 4
-        gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
-        gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
-        gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
-                                        ]) * 0.0015
-        gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
-                                        ]) * 0.0015
+    gmm1_weight_scale = None
+    gmm2_weight_scale = None
+    if w8a8_dynamic:
+        if is_shared_expert:
+            gmm1_weight = torch.ones([
+                local_expert_num, gmm1_input_dim, gmm1_output_dim
+            ]).to(torch.int8) * 4
+            gmm2_weight = torch.ones([
+                local_expert_num, gmm2_input_dim, gmm2_output_dim
+            ]).to(torch.int8) * 4
+            gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
+            gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
+            gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim
+                                            ]) * 0.0015
+            gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim
+                                            ]) * 0.0015
+        else:
+            gmm1_weight = torch.randint(
+                -16, 16,
+                [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(
+                    torch.int8)
+            gmm2_weight = torch.randint(
+                -16, 16,
+                [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(
+                    torch.int8)
+            gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
+                                            ]) * 0.003 + 0.0015
+            gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
+                                            ]) * 0.003 + 0.0015
     else:
-        gmm1_weight = torch.randint(
-            -16, 16,
-            [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
-        gmm2_weight = torch.randint(
-            -16, 16,
-            [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
-        gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim
-                                        ]) * 0.003 + 0.0015
-        gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim
-                                        ]) * 0.003 + 0.0015
+        # non-quantized path: bf16/fp16 weights, no weight scales needed
+        moe_dtype = torch.bfloat16 if test_bfloat16 else torch.float16
+        if is_shared_expert:
+            gmm1_weight = torch.ones([
+                local_expert_num, gmm1_input_dim, gmm1_output_dim
+            ]).to(moe_dtype) * 0.5
+            gmm2_weight = torch.ones([
+                local_expert_num, gmm2_input_dim, gmm2_output_dim
+            ]).to(moe_dtype) * 0.5
+        else:
+            gmm1_weight = torch.rand([
+                local_expert_num, gmm1_input_dim, gmm1_output_dim
+            ]).to(moe_dtype) * 0.25
+            gmm2_weight = torch.rand([
+                local_expert_num, gmm2_input_dim, gmm2_output_dim
+            ]).to(moe_dtype) * 0.25
+            gmm1_weight[:, ::2, :] = gmm1_weight[:, ::2, :] * -1
+            gmm2_weight[:, ::2, :] = gmm2_weight[:, ::2, :] * -1
     expert_scales = torch.rand(actual_bs, top_k)
     if test_bfloat16:
         x = x.bfloat16()
-        gmm1_weight_scale = gmm1_weight_scale.bfloat16()
-        gmm2_weight_scale = gmm2_weight_scale.bfloat16()
+        if w8a8_dynamic:
+            assert (gmm1_weight_scale is not None
+                    and gmm2_weight_scale is not None), (
+                        "gmm1_weight_scale and gmm2_weight_scale must be "
+                        "provided for w8a8_dynamic")
+            gmm1_weight_scale = gmm1_weight_scale.bfloat16()
+            gmm2_weight_scale = gmm2_weight_scale.bfloat16()
     else:
         x = x.half()
     smooth_sales = None
@@ -380,7 +434,9 @@ def run_once(local_rank_id,
                  enable_dynamic_bs=False,
                  test_graph=False,
                  with_mc2_mask=False,
-                 dynamic_eplb=False):
+                 dynamic_eplb=False,
+                 w8a8_dynamic=True,
+                 is_nz=True):
     log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                                ) if output_to_file(local_rank_id) else None
     global_rank_id = local_rank_id  # single node
@@ -407,7 +463,7 @@ def run_once(local_rank_id,
                              ep_world_size, moe_expert_num, global_rank_id,
                              shared_expert_rank_num)
     input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
-        *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
+        *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask, w8a8_dynamic)
     input_datas = [
         data.npu() if data is not None else None for data in input_datas
     ]
@@ -415,27 +471,52 @@ def run_once(local_rank_id,
     weight_datas = [
         data.npu() if data is not None else None for data in weight_datas
     ]
     small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter,
-                         dynamic_eplb).npu()  # type: ignore
+                         dynamic_eplb, w8a8_dynamic, is_nz).npu()  # type: ignore
     fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter,
-                         dynamic_eplb).npu()  # type: ignore
+                         dynamic_eplb, w8a8_dynamic, is_nz).npu()  # type: ignore
     if test_graph:
         config = torchair.CompilerConfig()
         config.mode = "reduce-overhead"
         npu_backend = torchair.get_npu_backend(compiler_config=config)
         fused_ops = torch.compile(fused_ops, backend=npu_backend)
+
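+    # Timing caveats: each loop below reports the *total* wall time of 100
+    # launches in microseconds, and no separate warm-up pass is done first,
+    # so any one-off compilation/graph-capture cost lands in the first total.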
+    # test performance
+    start_time = time.perf_counter()
+    for _ in range(100):
+        small_op_token_output, small_op_count_output = small_ops(*input_datas)
+    torch_npu.npu.synchronize(device_id)
+    end_time = time.perf_counter()
+    elapsed_time = end_time - start_time
+    elapsed_time_us = elapsed_time * 1000000
+    print(f"rank-{global_rank_id} small {elapsed_time_us} us")
+    start_time = time.perf_counter()
+    for _ in range(100):
+        fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
+    torch_npu.npu.synchronize(device_id)
+    end_time = time.perf_counter()
+    elapsed_time = end_time - start_time
+    elapsed_time_us = elapsed_time * 1000000
+    print(f"rank-{global_rank_id} fused {elapsed_time_us} us")
     small_op_token_output, small_op_count_output = small_ops(*input_datas)
+    torch_npu.npu.synchronize(device_id)
+    print(f"rank-{global_rank_id} Small op End")
     fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
     torch_npu.npu.synchronize(device_id)
+    print(f"rank-{global_rank_id} Fused op End")
     dist.destroy_process_group()
     if log_file is not None:
         log_file.close()
-
-    torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
-                               fused_op_token_output[0:valid_token_num].cpu(),
-                               atol=2.0,
-                               rtol=0.02)
-    torch.testing.assert_close(small_op_count_output.cpu(),
-                               fused_op_count_output.cpu())
+    # compare the small-op reference against the fused op; report instead of
+    # raising so every rank logs its own verdict
+    try:
+        torch.testing.assert_close(
+            small_op_token_output[0:valid_token_num].cpu(),
+            fused_op_token_output[0:valid_token_num].cpu(),
+            atol=2.0,
+            rtol=0.02)
+        torch.testing.assert_close(small_op_count_output.cpu(),
+                                   fused_op_count_output.cpu())
+    except Exception as e:
+        print(f"rank-{global_rank_id} Assert close Failed: {e}")
+    else:
+        print(f"rank-{global_rank_id} Assert close Pass")
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
@@ -444,9 +525,16 @@ def run_once(local_rank_id,
 @torch.inference_mode()
 def test_dispatch_gmm_combine_decode_base():
-    custom_kwargs = BASE_KWARGS
+    custom_kwargs = dict(BASE_KWARGS)  # copy so the overrides below do not
+    # leak into other tests that also start from BASE_KWARGS
+    custom_kwargs["batch_size"] = 32
+    custom_kwargs["ep_world_size"] = 8
+    custom_kwargs["moe_expert_num"] = 32
+    custom_kwargs["w8a8_dynamic"] = False
+    custom_kwargs["is_nz"] = True
     ep_world_size = custom_kwargs["ep_world_size"]
     custom_args = tuple(custom_kwargs.values())
+    print(f"{custom_kwargs=}")
     mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)


 @torch.inference_mode()
@@ -465,3 +553,6 @@ def test_dispatch_gmm_combine_decode_dynamic_eplb():
     ep_world_size = custom_kwargs["ep_world_size"]
     custom_args = tuple(custom_kwargs.values())
     mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
+
+if __name__ == "__main__":
+    test_dispatch_gmm_combine_decode_base()
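+
+# Note: direct invocation runs only the base (non-quantized, NZ-format)
+# configuration; a multi-NPU environment reachable over HCCL is assumed,
+# since mp.spawn starts ep_world_size worker processes. custom_args is built
+# from dict values, so the kwargs' insertion order is assumed to match
+# run_once's positional parameter order.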