From 25baf6df095f10c6af29fcafd56ad3ff3e83a9b2 Mon Sep 17 00:00:00 2001 From: wangyibo1005 <74347676+wangyibo1005@users.noreply.github.com> Date: Wed, 7 Jan 2026 11:23:42 +0800 Subject: [PATCH] [Feature]EPLB:Adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers (#5552) #### What this PR does / why we need it? This PR adapt DispatchGmmCombineDecode operator to eplb tensor list and expert token numbers. This operator support gmm1, gmm2, gmm1Scale and gmm2Scale in format of list. This operator support couting how many token each local expert recieves by expertTokensNum . - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/7157596103666ee7ccb7008acee8bff8a8ff1731 More info about this operator, please refer to RFC: issue https://github.com/vllm-project/vllm-ascend/issues/5476 --- .../aclnn_dispatch_gmm_combine_decode.cpp | 22 +- .../aclnn_dispatch_gmm_combine_decode.h | 10 +- .../dispatch_gmm_combine_decode_def.cpp | 12 +- .../dispatch_gmm_combine_decode_proto.cpp | 14 +- .../dispatch_gmm_combine_decode_tiling.cpp | 242 +++++++++++++----- .../op_kernel/dispatch_gmm_combine_decode.cpp | 7 +- .../op_kernel/dispatch_gmm_combine_decode.h | 16 +- ...m_per_token_dequant_multistage_workspace.h | 36 ++- ...equant_swiglu_quant_multistage_workspace.h | 88 ++++--- .../dispatch_gmm_combine_decode_tiling.h | 2 + csrc/torch_binding.cpp | 23 +- csrc/torch_binding_meta.cpp | 15 +- .../test_dispatch_gmm_combine_decode.py | 84 ++++-- vllm_ascend/ascend_forward_context.py | 12 +- vllm_ascend/envs.py | 2 +- vllm_ascend/eplb/adaptor/vllm_adaptor.py | 10 +- vllm_ascend/ops/fused_moe/moe_comm_method.py | 17 +- vllm_ascend/quantization/w8a8_dynamic.py | 8 +- 18 files changed, 425 insertions(+), 195 deletions(-) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp index 309e179a..2f9efc9b 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp @@ -26,10 +26,10 @@ extern "C" { extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -42,7 +42,7 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor); extern aclnnStatus aclnnInnerDispatchGmmCombineDecode( @@ -54,10 +54,10 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecode( aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -70,14 +70,14 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor) { return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, - output, epRecvCount, workspaceSize, executor); + output, expertTokenNums, workspaceSize, executor); } aclnnStatus aclnnDispatchGmmCombineDecode( diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h index 6601e2a1..d632e8c9 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h @@ -19,10 +19,10 @@ extern "C" { __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, - const aclTensor *gmm1PermutedWeight, - const aclTensor *gmm1PermutedWeightScale, - const aclTensor *gmm2Weight, - const aclTensor *gmm2WeightScale, + const aclTensorList *gmm1PermutedWeight, + const aclTensorList *gmm1PermutedWeightScale, + const aclTensorList *gmm2Weight, + const aclTensorList *gmm2WeightScale, const aclTensor *expertScales, const aclTensor *expertSmoothScalesOptional, const aclTensor *xActiveMaskOptional, @@ -35,7 +35,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode int64_t quantMode, int64_t globalBs, const aclTensor *output, - const aclTensor *epRecvCount, + const aclTensor *expertTokenNums, uint64_t *workspaceSize, aclOpExecutor **executor); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 520c64c4..511f653c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -26,22 +26,22 @@ public: .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm1_permuted_weight_scale") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm2_weight_scale") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); @@ -65,9 +65,9 @@ public: .DataType({ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("ep_recv_count") + this->Output("expert_token_nums") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) + .DataType({ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp index a0f1d219..7267f2b2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp @@ -18,7 +18,7 @@ namespace ge { constexpr uint32_t EXPAND_X_INDEX = 0; constexpr uint32_t EXPERT_IDS_INDEX = 1; constexpr uint32_t OUTPUT_X_INDEX = 0; -constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1; +constexpr uint32_t OUTPUT_EXPERT_TOKEN_NUMS = 1; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; @@ -36,9 +36,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context) const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX); const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX); gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX); - gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX); + gert::Shape *expertTokenNumsShape = context->GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS); if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr || - recvCountOutShape == nullptr) { + expertTokenNumsShape == nullptr) { return GRAPH_FAILED; } if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) { @@ -72,12 +72,12 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context) uint32_t epRankId = static_cast(*epRankIdPtr); uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); - recvCountOutShape->SetDimNum(1); + expertTokenNumsShape->SetDimNum(1); bool isShareExpert = (epRankId < sharedExpertRankNum); if (isShareExpert) { - recvCountOutShape->SetDim(0, epRankSize); + expertTokenNumsShape->SetDim(0, 1); } else { - recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum))); + expertTokenNumsShape->SetDim(0, moeExpertNum / (epRankSize - sharedExpertRankNum)); } return GRAPH_SUCCESS; @@ -87,7 +87,7 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) { const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX); context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType); - context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32); + context->SetOutputDataType(OUTPUT_EXPERT_TOKEN_NUMS, ge::DT_INT64); return ge::GRAPH_SUCCESS; } diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index 699b9e87..f4d3f430 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -58,6 +58,9 @@ constexpr uint32_t MIN_TOKEN_LENGTH = 512; constexpr uint32_t MAX_TOKEN_LENGTH = 7168; constexpr uint32_t MIN_GMM1_HIDDEN = 1024; constexpr uint32_t MAX_GMM1_HIDDEN = 6144; +constexpr uint32_t TENSOR_HIDDEN_INDEX = 1; +constexpr uint32_t SINGLE_HIDDEN_INDEX = 2; +constexpr uint32_t MAX_TENSOR_COUNT = 256; } // namespace namespace optiling { @@ -66,8 +69,176 @@ static size_t CeilUp(size_t x, size_t y) return (x + y - 1) / y * y; } -static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, - DispatchGmmCombineDecodeTilingData &tilingData) +static uint32_t CountTensorListLen(gert::TilingContext *context, int descIndex) +{ + int count = 0; + for (uint32_t i = 0; i < MAX_TENSOR_COUNT; i++) { + auto tensorElement = context->GetDynamicInputTensor(descIndex, i); + if (tensorElement == nullptr) { + break; + } + count++; + } + return count; +} + +static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint32_t gmm1ListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_INDEX); + auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0); + auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum(); + + OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm1ListLen > 1) { // List + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(gmm1ListLen != moeExpertNumPerRank, + OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number perRank."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(TENSOR_HIDDEN_INDEX); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = true; + } else { // Single + if (elementDims == 2) { // one localExpert perRank + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX - 1); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm1FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number per rank."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = + gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX); + } + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = false; + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + + uint32_t gmm1ScaleListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_SCALE_INDEX); + auto gmm1ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_SCALE_INDEX, 0); + auto gmm1ScaleFirstTensorElementShape = gmm1ScaleFirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm1ScaleFirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm1WeightScale shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm1ScaleListLen > 1) { // List + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 1) { // one localExpert perRank + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm1ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm1Scale does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + uint32_t gmm2ListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_INDEX); + auto gmm2FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_INDEX, 0); + auto gmm2FirstTensorElementShape = gmm2FirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm2ListLen > 1) { // List + OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank, + OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2 does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2 does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 2) { // one localExpert perRank + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm2FirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Weight does not match local expert num perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(2), + OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + const char *nodeName = context->GetNodeName(); + uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + + uint32_t gmm2ScaleListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_SCALE_INDEX); + auto gmm2ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_SCALE_INDEX, 0); + auto gmm2ScaleFirstTensorElementShape = gmm2ScaleFirstTensorElement->GetOriginShape(); + uint32_t elementDims = gmm2ScaleFirstTensorElementShape.GetDimNum(); + OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm2WeightScale shape is invalid."), + return ge::GRAPH_FAILED); + if (gmm2ScaleListLen > 1) { // List + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } else { // Single + if (elementDims == 1) { // one localExpert perRank + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } else { // multi localExperts perRank + OPS_ERR_IF(moeExpertNumPerRank != gmm2ScaleFirstTensorElementShape.GetDim(0), + OPS_LOG_E(nodeName, "gmm2Scale does not match local expert number perRank."), return ge::GRAPH_FAILED); + OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(1), + OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED); + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckWeightTensorList(gert::TilingContext *context, + DispatchGmmCombineDecodeTilingData *tilingData) +{ + if (CheckGmm1Shape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm1ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm2Shape(context, tilingData) == ge::GRAPH_SUCCESS && + CheckGmm2ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS) { + return ge::GRAPH_SUCCESS; + } + return ge::GRAPH_FAILED; +} + +ge::graphStatus CheckXActiveMaskShape(gert::TilingContext *context, const char *nodeName, + DispatchGmmCombineDecodeTilingData &tilingData) { uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; @@ -76,55 +247,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; - uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank; - const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); - OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm1WeightDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX); - OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.", - gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1); - OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2, - OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX); - OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm2WeightDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - - const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX); - OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.", - gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0); - OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum, - OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."), - return ge::GRAPH_FAILED); - const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1); - OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h), - return ge::GRAPH_FAILED); - const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( INPUT_SHARE_X_ACTIVE_MASK_INDEX); if (xActiveMaskStorageShape != nullptr) { @@ -312,21 +435,19 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK; OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); - const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); - OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."), - return ge::GRAPH_FAILED); - tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS); + OPS_ERR_IF(CheckWeightTensorList(context, tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckWeightTensorList failed."), return ge::GRAPH_FAILED); auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum; tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum; - OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), - return ge::GRAPH_FAILED); - OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E( - nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED); + OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED); + OPS_ERR_IF(CheckXActiveMaskShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "CheckXActiveMaskShape failed."), return ge::GRAPH_FAILED); OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, - OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); SetHcommCfg(context, tilingData, groupEp); const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape( INPUT_SHARE_X_ACTIVE_MASK_INDEX); @@ -338,6 +459,9 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) { tilingKey |= EXEC_FLAG_DEEP_FUSE; } + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) { + tilingKey |= EXEC_FLAG_TENSOR_LIST; + } context->SetTilingKey(tilingKey); context->SetBlockDim(aicNum); return ge::GRAPH_SUCCESS; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index 40d8e82e..02f8ada4 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -17,7 +17,7 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspace, GM_ADDR tiling) { @@ -25,10 +25,11 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); - if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) { + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { DispatchGmmCombineDecode op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, - expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data); + expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); } } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index 74691f94..bedbc52b 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, - GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, + GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums, uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) @@ -138,7 +138,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun gmEpSendCount, xActiveMask, gmResvered, - gmOutputRecvCount, + gmExpertTokenNums, epRankSize, epRankId, moeExpertNum, @@ -244,7 +244,7 @@ public: GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); @@ -257,7 +257,7 @@ private: GM_ADDR gmWeight2_; GM_ADDR gmScale2_; GM_ADDR gmOutput_; - GM_ADDR gmOutputRecvCount_; + GM_ADDR gmExpertTokenNums_; GM_ADDR workspaceGM_; GM_ADDR gmSmoothScales_; GM_ADDR gmexpertScales_; @@ -296,7 +296,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, // output - GM_ADDR output, GM_ADDR outputRecvCount, + GM_ADDR output, GM_ADDR expertTokenNums, // system GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { @@ -312,7 +312,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( gmWeight2_ = gmm2_weight; gmScale2_ = gmm2_weight_scale; gmOutput_ = output; - gmOutputRecvCount_ = outputRecvCount; + gmExpertTokenNums_ = expertTokenNums; workspaceGM_ = workspaceGM; gmexpertScales_ = expert_scales; xActiveMask_ = x_active_mask; @@ -396,7 +396,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() MoeDistributeDispatchImpl::CamMoeDistributeDispatch dispatcher; dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, - gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); + gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_); dispatcher.Process(); tpipe.Destroy(); icache_preload(8); @@ -416,7 +416,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered, - gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, + gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_); AscendC::PipeBarrier(); Arch::CrossCoreFlag gmm1AivFinished{0}; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h index 22cfe2b1..35b3512c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h @@ -10,6 +10,7 @@ #ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP #define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" #include "../../raw_distributed/cam_moe_distribute_combine.h" #include "catlass/catlass.hpp" #include "catlass/arch/cross_core_sync.hpp" @@ -122,7 +123,10 @@ public: AscendC::GlobalTensor gmA; gmA.SetGlobalBuffer(params.ptrA); AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } AscendC::GlobalTensor groupList; groupList.SetGlobalBuffer(params.ptrGroupList); @@ -139,6 +143,10 @@ public: uint32_t stageUsed = 0; uint32_t startCoreIdx = 0; for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; @@ -189,7 +197,9 @@ public: } gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } @@ -232,6 +242,12 @@ public: uint32_t stageId = 0; uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); @@ -241,12 +257,22 @@ public: LayoutPerTokenScale layoutPerTokenScale = params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); - EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, layoutScale, params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, params.ptrD + gmGroupOffsetD, layoutD}; + } blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); blockEpilogue.UpdateParams(epilogueParams); uint32_t coreLoops = blockScheduler.GetCoreLoops(); @@ -270,7 +296,9 @@ public: stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; } - gmGroupOffsetScale += inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 33d4c411..2420ba3e 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -9,6 +9,7 @@ */ #pragma once +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" #include "catlass/catlass.hpp" #include "catlass/arch/cross_core_sync.hpp" #include "catlass/arch/resource.hpp" @@ -416,7 +417,7 @@ public: GM_ADDR gmExpandIdx; GM_ADDR gmEpSendCount; GM_ADDR gmResvered; - GM_ADDR gmOutputRecvCount; + GM_ADDR gmExpertTokenNums; uint32_t epRankSize; uint32_t epRankId; @@ -440,7 +441,7 @@ public: LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_, - GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, + GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_, uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, uint32_t h) @@ -465,7 +466,7 @@ public: gmexpertIds(gmexpertIds_), gmExpandIdx(gmExpandIdx_), gmEpSendCount(gmEpSendCount_), - gmOutputRecvCount(gmOutputRecvCount_), + gmExpertTokenNums(gmExpertTokenNums_), gmXActiveMask(gmXActiveMask_), gmResvered(gmResvered_), epRankSize(epRankSize_), @@ -535,7 +536,10 @@ public: AscendC::GlobalTensor gmA; gmA.SetGlobalBuffer(params.ptrA); AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } AscendC::GlobalTensor groupList; groupList.SetGlobalBuffer(params.ptrGroupList); @@ -555,6 +559,10 @@ public: static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part uint32_t target = 1; for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + groupIdx * GROUP_INFO_SIZE); // wait AIV recv needed tokens @@ -619,7 +627,9 @@ public: } gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } startCoreIdx = (startCoreIdx + coreLoops) % aicNum; } @@ -1087,7 +1097,7 @@ public: } CATLASS_DEVICE - void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount) + void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset) { // calculate token index in output tensor int64_t subUbOffset = ubOffset; @@ -1113,15 +1123,6 @@ public: AscendC::WaitFlag(0); AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); - if (isRecvCore && recvCoreIdx == 0) { - AscendC::GlobalTensor recvCountTensor; - recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount); - AscendC::DataCopyExtParams dataCopyParams = { - 1U, static_cast(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U}; - AscendC::SetFlag(0); - AscendC::WaitFlag(0); - AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast(), dataCopyParams); - } AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); AscendC::PipeBarrier(); AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, @@ -1222,7 +1223,7 @@ public: } CATLASS_DEVICE - void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount) + void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount) { ubOffset = 0; RecvCount(ubOffset); @@ -1249,7 +1250,7 @@ public: if (startRankId < recvExpertNum) { // RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info - GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount); + GetCumSum(startRankId, recvExpertNum, ubOffset); RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); } @@ -1291,8 +1292,13 @@ public: uint32_t stageId = 0; uint32_t target = 1; uint32_t startCoreIdx = 0; - + AscendC::ListTensorDesc gmScaleListTensor; AscendC::GlobalTensor groupTokenNumStateTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { // just like AIC groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + @@ -1311,12 +1317,22 @@ public: LayoutPerTokenScale layoutPerTokenScale = wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); LayoutD layoutD = layout::RowMajor{currentM, n}; - EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, - layoutScale, - gmTokenScale + gmGroupOffsetPerTokenScale, - layoutPerTokenScale, - gmSwigluOutput + gmGroupOffsetD, - layoutD}; + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + } blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); blockEpilogue.UpdateParams(epilogueParams); uint32_t coreLoops = blockScheduler.GetCoreLoops(); @@ -1340,7 +1356,9 @@ public: stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; } - gmGroupOffsetScale += inGroupProblemShape.n(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); gmGroupOffsetD += currentM * n; @@ -1485,7 +1503,7 @@ public: } CATLASS_DEVICE - void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount) + void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums) { if (aivIdx == aiCoreGroupNum * subBlockNum - 1) { // clean @@ -1504,19 +1522,32 @@ public: expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList)); AscendC::GlobalTensor sendCountsGlobal; sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::GlobalTensor nonCumSumExpertTokenNumsTensor; + nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums); + uint32_t tmpTokenNum = 0; for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) { __asm__ __volatile__(""); AscendC::DataCacheCleanAndInvalid( sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]); __asm__ __volatile__(""); + uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1); expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum); + uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum; + nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum); + tmpTokenNum = tokenNum; + __asm__ __volatile__(""); AscendC::DataCacheCleanAndInvalid( expertTokenNumsOutGMTensor_[localMoeIndex]); __asm__ __volatile__(""); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + nonCumSumExpertTokenNumsTensor[localMoeIndex]); + __asm__ __volatile__(""); } } } @@ -1531,8 +1562,7 @@ public: (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask); } if (isRecvCore) { - RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount, - (GM_ADDR)params.gmOutputRecvCount); + RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount); } auto gmSwigluOutput = reinterpret_cast<__gm__ float *>( @@ -1547,7 +1577,7 @@ public: AscendC::SyncAll(); AscendC::PipeBarrier(); - UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount); + UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums); { // dynamic quant AscendC::GlobalTensor sendCountsGlobal; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h index 328538f2..3874ffa2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -30,6 +30,7 @@ struct DispatchGmmCombineDecodeInfo { uint64_t totalUbSize; uint64_t totalWinSize; uint64_t gmm1HLen; + bool isTensorList; }; struct DispatchGmmCombineDecodeTilingData { @@ -70,6 +71,7 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0; constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); +constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1); constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2); #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index a8077b67..5ff705aa 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -636,10 +636,10 @@ std::tuple grouped_matmul_swiglu_quant_weigh std::tuple dispatch_gmm_combine_decode( const at::Tensor &x, const at::Tensor &expert_ids, - const at::Tensor &gmm1_permuted_weight, - const at::Tensor &gmm1_permuted_weight_scale, - const at::Tensor &gmm2_weight, - const at::Tensor &gmm2_weight_scale, + const at::TensorList &gmm1_permuted_weight, + const at::TensorList &gmm1_permuted_weight_scale, + const at::TensorList &gmm2_weight, + const at::TensorList &gmm2_weight_scale, const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, const c10::optional &x_active_mask, @@ -660,7 +660,8 @@ std::tuple dispatch_gmm_combine_decode( bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); - at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options()); + auto opts = expert_ids.options().dtype(at::kLong); + at::Tensor expert_token_nums = at::empty({num_local_experts}, opts); vector group_ep_chrs(group_ep.begin(), group_ep.end()); group_ep_chrs.push_back('\0'); @@ -689,8 +690,8 @@ std::tuple dispatch_gmm_combine_decode( global_bs, // output tensors output, - ep_recv_count); - return {output, ep_recv_count}; + expert_token_nums); + return {output, expert_token_nums}; } void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, @@ -1287,16 +1288,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant); ops.def( - "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," - " Tensor gmm1_permuted_weight_scale," - " Tensor gmm2_weight, Tensor gmm2_weight_scale," + "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor[] gmm1_permuted_weight," + " Tensor[] gmm1_permuted_weight_scale," + " Tensor[] gmm2_weight, Tensor[] gmm2_weight_scale," " Tensor expert_scales, Tensor? expert_smooth_scales=None," " Tensor? x_active_mask=None," " str group_ep=''," " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0," " int shared_expert_num=1, int shared_expert_rank_num=0," " int quant_mode=0," - " int global_bs=0) -> (Tensor output, Tensor ep_recv_count)" + " int global_bs=0) -> (Tensor output, Tensor expert_token_nums)" ); ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index af2c237a..da902ed4 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -157,10 +157,10 @@ std::tuple grouped_matmul_swiglu_quant_weigh std::tuple dispatch_gmm_combine_decode_meta( const at::Tensor &x, const at::Tensor &expert_ids, - const at::Tensor &gmm1_permuted_weight, - const at::Tensor &gmm1_permuted_weight_scale, - const at::Tensor &gmm2_weight, - const at::Tensor &gmm2_weight_scale, + const at::TensorList &gmm1_permuted_weight, + const at::TensorList &gmm1_permuted_weight_scale, + const at::TensorList &gmm2_weight, + const at::TensorList &gmm2_weight_scale, const at::Tensor &expert_scales, const c10::optional &expert_smooth_scales, const c10::optional &x_active_mask, @@ -181,9 +181,10 @@ std::tuple dispatch_gmm_combine_decode_meta( bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); - at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta)); - - return {output, ep_recv_count}; + auto opts = expert_ids.options().dtype(at::kLong); + at::Tensor expert_token_nums = at::empty({num_local_experts}, opts.device(at::kMeta)); + + return {output, expert_token_nums}; } void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py index 03b4b964..547e5c1d 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops/test_dispatch_gmm_combine_decode.py @@ -27,7 +27,8 @@ BASE_KWARGS = { "test_bfloat16": True, "enable_dynamic_bs": False, "test_graph": False, - "with_mc2_mask": False + "with_mc2_mask": False, + "dynamic_eplb": False } @@ -48,20 +49,6 @@ def permute_weight(w: torch.Tensor, tile_n): n).contiguous() -def from_inclusive_prefix_sum(pref): - if isinstance(pref, torch.Tensor): - if pref.numel() == 0: - return pref - return torch.cat([pref[:1], pref[1:] - pref[:-1]]) - - if not pref: - return [] - out = [pref[0]] - for i in range(1, len(pref)): - out.append(pref[i] - pref[i - 1]) - return out - - def output_to_file(rank_id): return False @@ -80,7 +67,8 @@ class DecodeMoeOps(torch.nn.Module): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__() self.ep_hcomm_info = ep_hcomm_info self.batch_size = batch_size @@ -95,6 +83,7 @@ class DecodeMoeOps(torch.nn.Module): shared_expert_rank_num) self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank self.ep_recv_count_size = self.local_expert_num * ep_world_size + self.dynamic_eplb = dynamic_eplb self.gmm1_weight = torch.empty([ self.local_expert_num, self.token_hidden_size, self.moe_intermediate_size * 2 @@ -152,12 +141,13 @@ class SmallOps(DecodeMoeOps): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num) + shared_expert_rank_num, dynamic_eplb) self.tp_hcomm_info = "" def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, @@ -232,7 +222,7 @@ class SmallOps(DecodeMoeOps): shared_expert_num=1, shared_expert_rank_num=self.shared_expert_rank_num, global_bs=self.batch_size * self.ep_world_size) - return (combine_output, ep_send_counts[:self.ep_recv_count_size]) + return (combine_output, expert_token_nums) class FusionOp(DecodeMoeOps): @@ -249,12 +239,13 @@ class FusionOp(DecodeMoeOps): ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num=0): + shared_expert_rank_num=0, + dynamic_eplb=False): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num) + shared_expert_rank_num, dynamic_eplb) def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, x_active_mask): @@ -278,6 +269,34 @@ class FusionOp(DecodeMoeOps): global_bs=self.batch_size * self.ep_world_size) return output + def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale): + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, + torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, + torch_npu.Format.FRACTAL_NZ) + gmm1_weight_scale = gmm1_weight_scale.float() + gmm2_weight_scale = gmm2_weight_scale.float() + + if self.dynamic_eplb: + self.gmm1_weight = [ + weight.clone() for weight in gmm1_weight.unbind(dim=0) + ] + self.gmm1_weight_scale_fp32 = [ + weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) + ] + self.gmm2_weight = [ + weight.clone() for weight in gmm2_weight.unbind(dim=0) + ] + self.gmm2_weight_scale_fp32 = [ + weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) + ] + else: + self.gmm1_weight = [gmm1_weight.clone()] + self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] + self.gmm2_weight = [gmm2_weight.clone()] + self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + def generate_datas(batch_size, token_hidden_size, @@ -362,7 +381,8 @@ def run_once(local_rank_id, test_bfloat16=True, enable_dynamic_bs=False, test_graph=False, - with_mc2_mask=False): + with_mc2_mask=False, + dynamic_eplb=False): log_file = redirect_output(f"local_rank_{local_rank_id}.log" ) if output_to_file(local_rank_id) else None global_rank_id = local_rank_id # 单机 @@ -396,10 +416,10 @@ def run_once(local_rank_id, weight_datas = [ data.npu() if data is not None else None for data in weight_datas ] - small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, - *parameter).npu() # type: ignore - fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, - *parameter).npu() # type: ignore + small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter, + dynamic_eplb).npu() # type: ignore + fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter, + dynamic_eplb).npu() # type: ignore if test_graph: config = torchair.CompilerConfig() config.mode = "reduce-overhead" @@ -411,7 +431,7 @@ def run_once(local_rank_id, dist.destroy_process_group() if log_file is not None: log_file.close() - small_op_count_output = from_inclusive_prefix_sum(small_op_count_output) + torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(), fused_op_token_output[0:valid_token_num].cpu(), atol=2.0, @@ -431,9 +451,19 @@ def test_dispatch_gmm_combine_decode_base(): mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) +@torch.inference_mode() def test_dispatch_gmm_combine_decode_with_mc2_mask(): custom_kwargs = BASE_KWARGS custom_kwargs["with_mc2_mask"] = True ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values()) mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) + + +@torch.inference_mode() +def test_dispatch_gmm_combine_decode_dynamic_eplb(): + custom_kwargs = BASE_KWARGS + custom_kwargs["dynamic_eplb"] = True + ep_world_size = custom_kwargs["ep_world_size"] + custom_args = tuple(custom_kwargs.values()) + mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index be528453..06f5df1d 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -242,15 +242,14 @@ def select_moe_comm_method(num_tokens: int, ascend_config = get_ascend_config() dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes - # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16 - fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and ( - not dynamic_eplb) + fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" + dispatch_ffn_combine_enable = get_ep_group().world_size <= 16 and ( + not is_draft_model) and (not dynamic_eplb) if num_tokens <= mc2_tokens_capacity: fused_decode_enable = fused_mc2_enable if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: - fused_decode_enable = fused_mc2_enable and get_ep_group( - ).world_size <= 16 and (not is_draft_model) + fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: fused_decode_enable = fused_mc2_enable and \ speculative_enable_dispatch_gmm_combine_decode(vllm_config) @@ -258,8 +257,7 @@ def select_moe_comm_method(num_tokens: int, else: fused_prefill_enable = fused_mc2_enable if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: - fused_prefill_enable = fused_mc2_enable and get_ep_group( - ).world_size <= 16 and (not is_draft_model) + fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: fused_prefill_enable = False moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 4bd3987b..d4c8bf44 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -131,7 +131,7 @@ env_variables: Dict[str, Callable[[], Any]] = { # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb. # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator. # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer - # with W8A8, non-dynamic-eplb. And MTP layer must be W8A8. + # with W8A8. And MTP layer must be W8A8. "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')), # Whether to anbale balance scheduling diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 8aabcc3c..a3822c2f 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -54,12 +54,15 @@ class VllmEplbAdaptor(EplbAdaptor): self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ self.model.model.layers[i].mlp.experts.w2_weight_scale_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w13_weight_offset", - "w2_weight_scale_list", "w2_weight_offset" + "w2_weight_scale_list", "w2_weight_offset", + "w2_weight_scale_fp32_list" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -97,7 +100,8 @@ class VllmEplbAdaptor(EplbAdaptor): self.num_dense_layers) + ".mlp.experts." + name if name in [ "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w2_weight_scale_list" + "w13_weight_scale_fp32_list", "w2_weight_scale_list", + "w2_weight_scale_fp32_list" ]: expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() @@ -118,7 +122,7 @@ class VllmEplbAdaptor(EplbAdaptor): if name in [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", - "w2_weight_scale_list" + "w2_weight_scale_list", "w2_weight_scale_fp32_list" ]: per_expert_param.append( self.param_dict["model.layers." + str(layer_idx) + diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 07488f2d..8ad25e2f 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod): assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ "token_dispatcher must be an instance of TokenDispatcherWithMC2." + group_list_type = None + expert_tokens = None if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: out = torch.empty_like(hidden_states) torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore @@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod): ) elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: assert expert_map is not None, "expert_map cannot be None." - out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore + group_list_type = 1 + out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore x=hidden_states, expert_ids=topk_ids, - gmm1_permuted_weight=w1[0], - gmm1_permuted_weight_scale=w1_scale[0], - gmm2_weight=w2[0], - gmm2_weight_scale=w2_scale[0], + gmm1_permuted_weight=w1, + gmm1_permuted_weight_scale=w1_scale, + gmm2_weight=w2, + gmm2_weight_scale=w2_scale, expert_smooth_scales=None, expert_scales=topk_weights.to(torch.float32), group_ep=self.token_dispatcher.moe_all_to_all_group_name, @@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod): else: raise ValueError( f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") - return FusedExpertsResult(routed_out=out) + return FusedExpertsResult(routed_out=out, + group_list_type=group_list_type, + expert_tokens=expert_tokens) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index cba58850..b2e92e6e 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -254,7 +254,8 @@ class AscendW8A8DynamicFusedMoEMethod: w1 = layer.w13_weight_list w1_scale = layer.w13_weight_scale_fp32_list w2 = layer.w2_weight_list - w2_scale = layer.w2_weight_scale_list + w2_scale = layer.w2_weight_scale_fp32_list \ + if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list else: w1 = [layer.w13_weight] w1_scale = [layer.w13_weight_scale_fp32] @@ -333,11 +334,16 @@ class AscendW8A8DynamicFusedMoEMethod: weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0) ] + layer.w2_weight_scale_fp32_list = [ + weight.clone() + for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0) + ] del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w13_weight_scale_fp32 del layer.w2_weight_scale + del layer.w2_weight_scale_fp32 torch.npu.empty_cache()