diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp index 309e179a..2f9efc9b 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp @@ -26,10 +26,10 @@ extern "C" { extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -42,7 +42,7 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor); extern aclnnStatus aclnnInnerDispatchGmmCombineDecode( @@ -54,10 +54,10 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecode( aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -70,14 +70,14 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor) { return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, - output, epRecvCount, workspaceSize, executor); + output, expertTokenNums, workspaceSize, executor); } aclnnStatus aclnnDispatchGmmCombineDecode( diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h index 6601e2a1..d632e8c9 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h @@ -19,10 +19,10 @@ extern "C" { __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -35,7 +35,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 520c64c4..511f653c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -26,22 +26,22 @@ public: .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm1_permuted_weight_scale") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm2_weight_scale") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); @@ -65,9 +65,9 @@ public: .DataType({ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("ep_recv_count") + this->Output("expert_token_nums") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) + .DataType({ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp index a0f1d219..7267f2b2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp @@ -18,7 +18,7 @@ namespace ge { constexpr uint32_t EXPAND_X_INDEX = 0; constexpr uint32_t EXPERT_IDS_INDEX = 1; constexpr uint32_t OUTPUT_X_INDEX = 0; -constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1; +constexpr uint32_t OUTPUT_EXPERT_TOKEN_NUMS = 1; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; @@ -36,9 +36,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context) const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX); const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX); gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX); - gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX); + gert::Shape *expertTokenNumsShape = context->GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS); if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr || - recvCountOutShape == nullptr) { + expertTokenNumsShape == nullptr) { return GRAPH_FAILED; } if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) { @@ -72,12 +72,12 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context) uint32_t epRankId = static_cast(*epRankIdPtr); uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); - recvCountOutShape->SetDimNum(1); + expertTokenNumsShape->SetDimNum(1); bool isShareExpert = (epRankId < sharedExpertRankNum); if (isShareExpert) { - recvCountOutShape->SetDim(0, epRankSize); + expertTokenNumsShape->SetDim(0, 1); } else { - recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum))); + expertTokenNumsShape->SetDim(0, moeExpertNum / (epRankSize - sharedExpertRankNum)); } return GRAPH_SUCCESS; @@ -87,7 +87,7 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) { const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX); context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType); - context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32); + context->SetOutputDataType(OUTPUT_EXPERT_TOKEN_NUMS, ge::DT_INT64); return ge::GRAPH_SUCCESS; } diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index 699b9e87..f4d3f430 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -58,6 +58,9 @@ constexpr uint32_t MIN_TOKEN_LENGTH = 512; constexpr uint32_t MAX_TOKEN_LENGTH = 7168; constexpr uint32_t MIN_GMM1_HIDDEN = 1024; constexpr uint32_t MAX_GMM1_HIDDEN = 6144; +constexpr uint32_t TENSOR_HIDDEN_INDEX = 1; +constexpr uint32_t SINGLE_HIDDEN_INDEX = 2; +constexpr uint32_t MAX_TENSOR_COUNT = 256; } // namespace namespace optiling { @@ -66,8 +69,176 @@ static size_t CeilUp(size_t x, size_t y) return (x + y - 1) / y * y; } -static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, - DispatchGmmCombineDecodeTilingData &tilingData) +static uint32_t CountTensorListLen(gert::TilingContext *context, int descIndex) +{ + int count = 0; + for (uint32_t i = 0; i < MAX_TENSOR_COUNT; i++) { + auto tensorElement = context->GetDynamicInputTensor(descIndex, i); + if (tensorElement == nullptr) { + break; + } + count++; + } + return count; +} + +static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint32_t gmm1ListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_INDEX); + auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0); + auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum(); + + OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm1ListLen > 1) { // List + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(gmm1ListLen != moeExpertNumPerRank, + OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number perRank."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(TENSOR_HIDDEN_INDEX); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = true; + } else { // Single + if (elementDims == 2) { // one localExpert perRank + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX - 1); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number per rank."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX); + } + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = false; + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + + uint32_t gmm1ScaleListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_SCALE_INDEX); + auto gmm1ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_SCALE_INDEX, 0); + auto gmm1ScaleFirstTensorElementShape = gmm1ScaleFirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm1ScaleFirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm1WeightScale shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm1ScaleListLen > 1) { // List + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 1) { // one localExpert perRank + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + uint32_t gmm2ListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_INDEX); + auto gmm2FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_INDEX, 0); + auto gmm2FirstTensorElementShape = gmm2FirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm2ListLen > 1) { // List + OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank, + OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2 does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2 does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 2) { // one localExpert perRank + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Weight does not match local expert num perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(2), + OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + + uint32_t gmm2ScaleListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_SCALE_INDEX); + auto gmm2ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_SCALE_INDEX, 0); + auto gmm2ScaleFirstTensorElementShape = gmm2ScaleFirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm2ScaleFirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm2WeightScale shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm2ScaleListLen > 1) { // List + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 1) { // one localExpert perRank + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckWeightTensorList(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + if (CheckGmm1Shape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm1ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm2Shape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm2ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS) { + return ge::GRAPH_SUCCESS; + } + return ge::GRAPH_FAILED; +} + +ge::graphStatus CheckXActiveMaskShape(gert::TilingContext *context, const char *nodeName, + DispatchGmmCombineDecodeTilingData &tilingData) { uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; @@ -76,55 +247,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; - uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank; - const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); - OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm1WeightDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX); - OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.", - gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1); - OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2, - OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX); - OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm2WeightDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX); - OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.", - gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1); - OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h), - return ge::GRAPH_FAILED); - const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( INPUT_SHARE_X_ACTIVE_MASK_INDEX); if (xActiveMaskStorageShape != nullptr) { @@ -312,21 +435,19 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK; OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); - const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); - OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."), - return ge::GRAPH_FAILED); - tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS); + OPS_ERR_IF(CheckWeightTensorList(context, tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckWeightTensorList failed."), return ge::GRAPH_FAILED); auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum; tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum; - OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E( - nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED); + OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED); + OPS_ERR_IF(CheckXActiveMaskShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckXActiveMaskShape failed."), return ge::GRAPH_FAILED); OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); SetHcommCfg(context, tilingData, groupEp); const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( INPUT_SHARE_X_ACTIVE_MASK_INDEX); @@ -338,6 +459,9 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) { tilingKey |= EXEC_FLAG_DEEP_FUSE; } + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) { + tilingKey |= EXEC_FLAG_TENSOR_LIST; + } context->SetTilingKey(tilingKey); context->SetBlockDim(aicNum); return ge::GRAPH_SUCCESS; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index 40d8e82e..02f8ada4 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -17,7 +17,7 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspace, GM_ADDR tiling) { @@ -25,10 +25,11 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); - if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) { + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { DispatchGmmCombineDecode op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, - expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data); + expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); } } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index 74691f94..bedbc52b 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, - GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, + GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums, uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) @@ -138,7 +138,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun gmEpSendCount, xActiveMask, gmResvered, - gmOutputRecvCount, + gmExpertTokenNums, epRankSize, epRankId, moeExpertNum, @@ -244,7 +244,7 @@ public: GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); @@ -257,7 +257,7 @@ private: GM_ADDR gmWeight2_; GM_ADDR gmScale2_; GM_ADDR gmOutput_; - GM_ADDR gmOutputRecvCount_; + GM_ADDR gmExpertTokenNums_; GM_ADDR workspaceGM_; GM_ADDR gmSmoothScales_; GM_ADDR gmexpertScales_; @@ -296,7 +296,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { @@ -312,7 +312,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( gmWeight2_ = gmm2_weight; gmScale2_ = gmm2_weight_scale; gmOutput_ = output; - gmOutputRecvCount_ = outputRecvCount; + gmExpertTokenNums_ = expertTokenNums; workspaceGM_ = workspaceGM; gmexpertScales_ = expert_scales; xActiveMask_ = x_active_mask; @@ -396,7 +396,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() MoeDistributeDispatchImpl::CamMoeDistributeDispatch dispatcher; dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, - gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); + gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_); dispatcher.Process(); tpipe.Destroy(); icache_preload(8); @@ -416,7 +416,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered, - gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, + gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_); AscendC::PipeBarrier(); Arch::CrossCoreFlag gmm1AivFinished{0}; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h index 22cfe2b1..35b3512c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h @@ -10,6 +10,7 @@ #ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP #define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" #include "../../raw_distributed/cam_moe_distribute_combine.h" #include "catlass/catlass.hpp" #include "catlass/arch/cross_core_sync.hpp" @@ -122,7 +123,10 @@ public: AscendC::GlobalTensor gmA; gmA.SetGlobalBuffer(params.ptrA); AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } AscendC::GlobalTensor groupList; groupList.SetGlobalBuffer(params.ptrGroupList); @@ -139,6 +143,10 @@ public: uint32_t stageUsed = 0; uint32_t startCoreIdx = 0; for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; @@ -189,7 +197,9 @@ public: } gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } @@ -232,6 +242,12 @@ public: uint32_t stageId = 0; uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); @@ -241,12 +257,22 @@ public: LayoutPerTokenScale layoutPerTokenScale = params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); - EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, layoutScale, params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, params.ptrD + gmGroupOffsetD, layoutD}; + } blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); blockEpilogue.UpdateParams(epilogueParams); uint32_t coreLoops = blockScheduler.GetCoreLoops(); @@ -270,7 +296,9 @@ public: stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; } - gmGroupOffsetScale += inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 33d4c411..2420ba3e 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -9,6 +9,7 @@ */ #pragma once +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" #include "catlass/catlass.hpp" #include "catlass/arch/cross_core_sync.hpp" #include "catlass/arch/resource.hpp" @@ -416,7 +417,7 @@ public: GM_ADDR gmExpandIdx; GM_ADDR gmEpSendCount; GM_ADDR gmResvered; - GM_ADDR gmOutputRecvCount; + GM_ADDR gmExpertTokenNums; uint32_t epRankSize; uint32_t epRankId; @@ -440,7 +441,7 @@ public: LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_, - GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, + GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_, uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, uint32_t h) @@ -465,7 +466,7 @@ public: gmexpertIds(gmexpertIds_), gmExpandIdx(gmExpandIdx_), gmEpSendCount(gmEpSendCount_), - gmOutputRecvCount(gmOutputRecvCount_), + gmExpertTokenNums(gmExpertTokenNums_), gmXActiveMask(gmXActiveMask_), gmResvered(gmResvered_), epRankSize(epRankSize_), @@ -535,7 +536,10 @@ public: AscendC::GlobalTensor gmA; gmA.SetGlobalBuffer(params.ptrA); AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } AscendC::GlobalTensor groupList; groupList.SetGlobalBuffer(params.ptrGroupList); @@ -555,6 +559,10 @@ public: static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part uint32_t target = 1; for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + groupIdx * GROUP_INFO_SIZE); // wait AIV recv needed tokens @@ -619,7 +627,9 @@ public: } gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } startCoreIdx = (startCoreIdx + coreLoops) % aicNum; } @@ -1087,7 +1097,7 @@ public: } CATLASS_DEVICE - void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount) + void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset) { // calculate token index in output tensor int64_t subUbOffset = ubOffset; @@ -1113,15 +1123,6 @@ public: AscendC::WaitFlag(0); AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); - if (isRecvCore && recvCoreIdx == 0) { - AscendC::GlobalTensor recvCountTensor; - recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount); - AscendC::DataCopyExtParams dataCopyParams = { - 1U, static_cast(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U}; - AscendC::SetFlag(0); - AscendC::WaitFlag(0); - AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast(), dataCopyParams); - } AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); AscendC::PipeBarrier(); AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, @@ -1222,7 +1223,7 @@ public: } CATLASS_DEVICE - void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount) + void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount) { ubOffset = 0; RecvCount(ubOffset); @@ -1249,7 +1250,7 @@ public: if (startRankId < recvExpertNum) { // RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info - GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount); + GetCumSum(startRankId, recvExpertNum, ubOffset); RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); } @@ -1291,8 +1292,13 @@ public: uint32_t stageId = 0; uint32_t target = 1; uint32_t startCoreIdx = 0; - + AscendC::ListTensorDesc gmScaleListTensor; AscendC::GlobalTensor groupTokenNumStateTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { // just like AIC groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + @@ -1311,12 +1317,22 @@ public: LayoutPerTokenScale layoutPerTokenScale = wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); LayoutD layoutD = layout::RowMajor{currentM, n}; - EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, - layoutScale, - gmTokenScale + gmGroupOffsetPerTokenScale, - layoutPerTokenScale, - gmSwigluOutput + gmGroupOffsetD, - layoutD}; + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + } blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); blockEpilogue.UpdateParams(epilogueParams); uint32_t coreLoops = blockScheduler.GetCoreLoops(); @@ -1340,7 +1356,9 @@ public: stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; } - gmGroupOffsetScale += inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetD += currentM * n; @@ -1485,7 +1503,7 @@ public: } CATLASS_DEVICE - void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount) + void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums) { if (aivIdx == aiCoreGroupNum * subBlockNum - 1) { // clean @@ -1504,19 +1522,32 @@ public: expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList)); AscendC::GlobalTensor sendCountsGlobal; sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::GlobalTensor nonCumSumExpertTokenNumsTensor; + nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums); + uint32_t tmpTokenNum = 0; for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) { __asm__ __volatile__(""); AscendC::DataCacheCleanAndInvalid( sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]); __asm__ __volatile__(""); + uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1); expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum); + uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum; + nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum); + tmpTokenNum = tokenNum; + __asm__ __volatile__(""); AscendC::DataCacheCleanAndInvalid( expertTokenNumsOutGMTensor_[localMoeIndex]); __asm__ __volatile__(""); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + nonCumSumExpertTokenNumsTensor[localMoeIndex]); + __asm__ __volatile__(""); } } } @@ -1531,8 +1562,7 @@ public: (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask); } if (isRecvCore) { - RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount, - (GM_ADDR)params.gmOutputRecvCount); + RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount); } auto gmSwigluOutput = reinterpret_cast<__gm__ float *>( @@ -1547,7 +1577,7 @@ public: AscendC::SyncAll(); AscendC::PipeBarrier(); - UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount); + UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums); { // dynamic quant AscendC::GlobalTensor sendCountsGlobal; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h index 328538f2..3874ffa2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -30,6 +30,7 @@ struct DispatchGmmCombineDecodeInfo { uint64_t totalUbSize; uint64_t totalWinSize; uint64_t gmm1HLen; + bool isTensorList; }; struct DispatchGmmCombineDecodeTilingData { @@ -70,6 +71,7 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0; constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); +constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1); constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2); #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index a8077b67..5ff705aa 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -636,10 +636,10 @@ std::tuple grouped_matmul_swiglu_quant_weigh std::tuple dispatch_gmm_combine_decode( const at::Tensor &x, const at::Tensor &expert_ids, - const at::Tensor &gmm1_permuted_weight, - const at::Tensor &gmm1_permuted_weight_scale, - const at::Tensor &gmm2_weight, - const at::Tensor &gmm2_weight_scale, + const at::TensorList &gmm1_permuted_weight, + const at::TensorList &gmm1_permuted_weight_scale, + const at::TensorList &gmm2_weight, + const at::TensorList &gmm2_weight_scale, const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, const c10::optional &x_active_mask, @@ -660,7 +660,8 @@ std::tuple dispatch_gmm_combine_decode( bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); - at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options()); + auto opts = expert_ids.options().dtype(at::kLong); + at::Tensor expert_token_nums = at::empty({num_local_experts}, opts); vector group_ep_chrs(group_ep.begin(), group_ep.end()); group_ep_chrs.push_back('\0'); @@ -689,8 +690,8 @@ std::tuple dispatch_gmm_combine_decode( global_bs, // output tensors output, - ep_recv_count); - return {output, ep_recv_count}; + expert_token_nums); + return {output, expert_token_nums}; } void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, @@ -1287,16 +1288,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant); ops.def( - "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," - " Tensor gmm1_permuted_weight_scale," - " Tensor gmm2_weight, Tensor gmm2_weight_scale," + "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor[] gmm1_permuted_weight," + " Tensor[] gmm1_permuted_weight_scale," + " Tensor[] gmm2_weight, Tensor[] gmm2_weight_scale," " Tensor expert_scales, Tensor? expert_smooth_scales=None," " Tensor? x_active_mask=None," " str group_ep=''," " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0," " int shared_expert_num=1, int shared_expert_rank_num=0," " int quant_mode=0," - " int global_bs=0) -> (Tensor output, Tensor ep_recv_count)" + " int global_bs=0) -> (Tensor output, Tensor expert_token_nums)" ); ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index af2c237a..da902ed4 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -157,10 +157,10 @@ std::tuple grouped_matmul_swiglu_quant_weigh std::tuple dispatch_gmm_combine_decode_meta( const at::Tensor &x, const at::Tensor &expert_ids, - const at::Tensor &gmm1_permuted_weight, - const at::Tensor &gmm1_permuted_weight_scale, - const at::Tensor &gmm2_weight, - const at::Tensor &gmm2_weight_scale, + const at::TensorList &gmm1_permuted_weight, + const at::TensorList &gmm1_permuted_weight_scale, + const at::TensorList &gmm2_weight, + const at::TensorList &gmm2_weight_scale, const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, const c10::optional &x_active_mask, @@ -181,9 +181,10 @@ std::tuple dispatch_gmm_combine_decode_meta( bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); - at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta)); - - return {output, ep_recv_count}; + auto opts = expert_ids.options().dtype(at::kLong); + at::Tensor expert_token_nums = at::empty({num_local_experts}, opts.device(at::kMeta)); + + return {output, expert_token_nums}; } void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py index 03b4b964..547e5c1d 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py @@ -27,7 +27,8 @@ BASE_KWARGS = { "test_bfloat16": True, "enable_dynamic_bs": False, "test_graph": False, - "with_mc2_mask": False + "with_mc2_mask": False, + "dynamic_eplb": False } @@ -48,20 +49,6 @@ def permute_weight(w: torch.Tensor, tile_n): n).contiguous() -def from_inclusive_prefix_sum(pref): - if isinstance(pref, torch.Tensor): - if pref.numel() == 0: - return pref - return torch.cat([pref[:1], pref[1:] - pref[:-1]]) - - if not pref: - return [] - out = [pref[0]] - for i in range(1, len(pref)): - out.append(pref[i] - pref[i - 1]) - return out - - def output_to_file(rank_id): return False @@ -80,7 +67,8 @@ class DecodeMoeOps(torch.nn.Module): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__() self.ep_hcomm_info = ep_hcomm_info self.batch_size = batch_size @@ -95,6 +83,7 @@ class DecodeMoeOps(torch.nn.Module): shared_expert_rank_num) self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank self.ep_recv_count_size = self.local_expert_num * ep_world_size + self.dynamic_eplb = dynamic_eplb self.gmm1_weight = torch.empty([ self.local_expert_num, self.token_hidden_size, self.moe_intermediate_size * 2 @@ -152,12 +141,13 @@ class SmallOps(DecodeMoeOps): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num) + shared_expert_rank_num, dynamic_eplb) self.tp_hcomm_info = "" def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, @@ -232,7 +222,7 @@ class SmallOps(DecodeMoeOps): shared_expert_num=1, shared_expert_rank_num=self.shared_expert_rank_num, global_bs=self.batch_size * self.ep_world_size) - return (combine_output, ep_send_counts[:self.ep_recv_count_size]) + return (combine_output, expert_token_nums) class FusionOp(DecodeMoeOps): @@ -249,12 +239,13 @@ class FusionOp(DecodeMoeOps): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num) + shared_expert_rank_num, dynamic_eplb) def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, x_active_mask): @@ -278,6 +269,34 @@ class FusionOp(DecodeMoeOps): global_bs=self.batch_size * self.ep_world_size) return output + def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale): + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, + torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, + torch_npu.Format.FRACTAL_NZ) + gmm1_weight_scale = gmm1_weight_scale.float() + gmm2_weight_scale = gmm2_weight_scale.float() + + if self.dynamic_eplb: + self.gmm1_weight = [ + weight.clone() for weight in gmm1_weight.unbind(dim=0) + ] + self.gmm1_weight_scale_fp32 = [ + weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) + ] + self.gmm2_weight = [ + weight.clone() for weight in gmm2_weight.unbind(dim=0) + ] + self.gmm2_weight_scale_fp32 = [ + weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) + ] + else: + self.gmm1_weight = [gmm1_weight.clone()] + self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] + self.gmm2_weight = [gmm2_weight.clone()] + self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + def generate_datas(batch_size, token_hidden_size, @@ -362,7 +381,8 @@ def run_once(local_rank_id, test_bfloat16=True, enable_dynamic_bs=False, test_graph=False, - with_mc2_mask=False): + with_mc2_mask=False, + dynamic_eplb=False): log_file = redirect_output(f"local_rank_{local_rank_id}.log" ) if output_to_file(local_rank_id) else None global_rank_id = local_rank_id # 单机 @@ -396,10 +416,10 @@ def run_once(local_rank_id, weight_datas = [ data.npu() if data is not None else None for data in weight_datas ] - small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, - *parameter).npu() # type: ignore - fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, - *parameter).npu() # type: ignore + small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter, + dynamic_eplb).npu() # type: ignore + fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter, + dynamic_eplb).npu() # type: ignore if test_graph: config = torchair.CompilerConfig() config.mode = "reduce-overhead" @@ -411,7 +431,7 @@ def run_once(local_rank_id, dist.destroy_process_group() if log_file is not None: log_file.close() - small_op_count_output = from_inclusive_prefix_sum(small_op_count_output) + torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(), fused_op_token_output[0:valid_token_num].cpu(), atol=2.0, @@ -431,9 +451,19 @@ def test_dispatch_gmm_combine_decode_base(): mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) +@torch.inference_mode() def test_dispatch_gmm_combine_decode_with_mc2_mask(): custom_kwargs = BASE_KWARGS custom_kwargs["with_mc2_mask"] = True ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values()) mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) + + +@torch.inference_mode() +def test_dispatch_gmm_combine_decode_dynamic_eplb(): + custom_kwargs = BASE_KWARGS + custom_kwargs["dynamic_eplb"] = True + ep_world_size = custom_kwargs["ep_world_size"] + custom_args = tuple(custom_kwargs.values()) + mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index be528453..06f5df1d 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -242,15 +242,14 @@ def select_moe_comm_method(num_tokens: int, ascend_config = get_ascend_config() dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes - # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16 - fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and ( - not dynamic_eplb) + fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" + dispatch_ffn_combine_enable = get_ep_group().world_size <= 16 and ( + not is_draft_model) and (not dynamic_eplb) if num_tokens <= mc2_tokens_capacity: fused_decode_enable = fused_mc2_enable if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: - fused_decode_enable = fused_mc2_enable and get_ep_group( - ).world_size <= 16 and (not is_draft_model) + fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: fused_decode_enable = fused_mc2_enable and \ speculative_enable_dispatch_gmm_combine_decode(vllm_config) @@ -258,8 +257,7 @@ def select_moe_comm_method(num_tokens: int, else: fused_prefill_enable = fused_mc2_enable if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: - fused_prefill_enable = fused_mc2_enable and get_ep_group( - ).world_size <= 16 and (not is_draft_model) + fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: fused_prefill_enable = False moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 4bd3987b..d4c8bf44 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -131,7 +131,7 @@ env_variables: Dict[str, Callable[[], Any]] = { # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb. # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator. # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer - # with W8A8, non-dynamic-eplb. And MTP layer must be W8A8. + # with W8A8. And MTP layer must be W8A8. "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')), # Whether to anbale balance scheduling diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 8aabcc3c..a3822c2f 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -54,12 +54,15 @@ class VllmEplbAdaptor(EplbAdaptor): self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ self.model.model.layers[i].mlp.experts.w2_weight_scale_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w13_weight_offset", - "w2_weight_scale_list", "w2_weight_offset" + "w2_weight_scale_list", "w2_weight_offset", + "w2_weight_scale_fp32_list" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -97,7 +100,8 @@ class VllmEplbAdaptor(EplbAdaptor): self.num_dense_layers) + ".mlp.experts." + name if name in [ "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w2_weight_scale_list" + "w13_weight_scale_fp32_list", "w2_weight_scale_list", + "w2_weight_scale_fp32_list" ]: expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() @@ -118,7 +122,7 @@ class VllmEplbAdaptor(EplbAdaptor): if name in [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", - "w2_weight_scale_list" + "w2_weight_scale_list", "w2_weight_scale_fp32_list" ]: per_expert_param.append( self.param_dict["model.layers." + str(layer_idx) + diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 07488f2d..8ad25e2f 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod): assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ "token_dispatcher must be an instance of TokenDispatcherWithMC2." + group_list_type = None + expert_tokens = None if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: out = torch.empty_like(hidden_states) torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore @@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod): ) elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: assert expert_map is not None, "expert_map cannot be None." - out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore + group_list_type = 1 + out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore x=hidden_states, expert_ids=topk_ids, - gmm1_permuted_weight=w1[0], - gmm1_permuted_weight_scale=w1_scale[0], - gmm2_weight=w2[0], - gmm2_weight_scale=w2_scale[0], + gmm1_permuted_weight=w1, + gmm1_permuted_weight_scale=w1_scale, + gmm2_weight=w2, + gmm2_weight_scale=w2_scale, expert_smooth_scales=None, expert_scales=topk_weights.to(torch.float32), group_ep=self.token_dispatcher.moe_all_to_all_group_name, @@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod): else: raise ValueError( f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") - return FusedExpertsResult(routed_out=out) + return FusedExpertsResult(routed_out=out, + group_list_type=group_list_type, + expert_tokens=expert_tokens) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index cba58850..b2e92e6e 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -254,7 +254,8 @@ class AscendW8A8DynamicFusedMoEMethod: w1 = layer.w13_weight_list w1_scale = layer.w13_weight_scale_fp32_list w2 = layer.w2_weight_list - w2_scale = layer.w2_weight_scale_list + w2_scale = layer.w2_weight_scale_fp32_list \ + if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list else: w1 = [layer.w13_weight] w1_scale = [layer.w13_weight_scale_fp32] @@ -333,11 +334,16 @@ class AscendW8A8DynamicFusedMoEMethod: weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0) ] + layer.w2_weight_scale_fp32_list = [ + weight.clone() + for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0) + ] del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w13_weight_scale_fp32 del layer.w2_weight_scale + del layer.w2_weight_scale_fp32 torch.npu.empty_cache()