[Feature] EPLB: Adapt DispatchGmmCombineDecode operator to EPLB tensor list and expert token numbers (#5552)

#### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to accept EPLB tensor
lists and to report per-expert token counts.

The operator now supports gmm1, gmm2, gmm1Scale, and gmm2Scale passed as
tensor lists.
The operator also reports how many tokens each local expert receives via the
expertTokenNums output.


- vLLM version: v0.13.0
- vLLM main: 7157596103

For more information about this operator, see the RFC:
https://github.com/vllm-project/vllm-ascend/issues/5476
Author: wangyibo1005
Committed: 2026-01-07 11:23:42 +08:00 (via GitHub)
Parent: 086c093347 · Commit: 25baf6df09
18 changed files with 425 additions and 195 deletions

View File

@@ -26,10 +26,10 @@ extern "C" {
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -42,7 +42,7 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor);
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
@@ -54,10 +54,10 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -70,14 +70,14 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
output, expertTokenNums, workspaceSize, executor);
}
aclnnStatus aclnnDispatchGmmCombineDecode(

View File

@@ -19,10 +19,10 @@ extern "C" {
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensorList *gmm1PermutedWeight,
const aclTensorList *gmm1PermutedWeightScale,
const aclTensorList *gmm2Weight,
const aclTensorList *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *xActiveMaskOptional,
@@ -35,7 +35,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
const aclTensor *expertTokenNums,
uint64_t *workspaceSize,
aclOpExecutor **executor);

View File

@@ -26,22 +26,22 @@ public:
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
@@ -65,9 +65,9 @@ public:
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
this->Output("expert_token_nums")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.DataType({ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();

View File

@@ -18,7 +18,7 @@ namespace ge {
constexpr uint32_t EXPAND_X_INDEX = 0;
constexpr uint32_t EXPERT_IDS_INDEX = 1;
constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1;
constexpr uint32_t OUTPUT_EXPERT_TOKEN_NUMS = 1;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
@@ -36,9 +36,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX);
const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX);
gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX);
gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX);
gert::Shape *expertTokenNumsShape = context->GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS);
if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr ||
recvCountOutShape == nullptr) {
expertTokenNumsShape == nullptr) {
return GRAPH_FAILED;
}
if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) {
@@ -72,12 +72,12 @@ static ge::graphStatus InferShape(gert::InferShapeContext *context)
uint32_t epRankId = static_cast<uint32_t>(*epRankIdPtr);
uint32_t sharedExpertRankNum = static_cast<uint32_t>(*sharedExpertRankNumPtr);
recvCountOutShape->SetDimNum(1);
expertTokenNumsShape->SetDimNum(1);
bool isShareExpert = (epRankId < sharedExpertRankNum);
if (isShareExpert) {
recvCountOutShape->SetDim(0, epRankSize);
expertTokenNumsShape->SetDim(0, 1);
} else {
recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)));
expertTokenNumsShape->SetDim(0, moeExpertNum / (epRankSize - sharedExpertRankNum));
}
return GRAPH_SUCCESS;
@@ -87,7 +87,7 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX);
context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType);
context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32);
context->SetOutputDataType(OUTPUT_EXPERT_TOKEN_NUMS, ge::DT_INT64);
return ge::GRAPH_SUCCESS;
}

View File

@@ -58,6 +58,9 @@ constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
constexpr uint32_t MIN_GMM1_HIDDEN = 1024;
constexpr uint32_t MAX_GMM1_HIDDEN = 6144;
constexpr uint32_t TENSOR_HIDDEN_INDEX = 1;
constexpr uint32_t SINGLE_HIDDEN_INDEX = 2;
constexpr uint32_t MAX_TENSOR_COUNT = 256;
} // namespace
namespace optiling {
@@ -66,8 +69,176 @@ static size_t CeilUp(size_t x, size_t y)
return (x + y - 1) / y * y;
}
static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
static uint32_t CountTensorListLen(gert::TilingContext *context, int descIndex)
{
int count = 0;
for (uint32_t i = 0; i < MAX_TENSOR_COUNT; i++) {
auto tensorElement = context->GetDynamicInputTensor(descIndex, i);
if (tensorElement == nullptr) {
break;
}
count++;
}
return count;
}
static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t gmm1ListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_INDEX);
auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0);
auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape();
uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum();
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."),
return ge::GRAPH_FAILED);
if (gmm1ListLen > 1) { // List
OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm1ListLen != moeExpertNumPerRank,
OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number perRank."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
gmm1FirstTensorElementShape.GetDim(TENSOR_HIDDEN_INDEX);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = true;
} else { // Single
if (elementDims == 2) { // one localExpert perRank
OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX - 1);
} else { // multi localExperts perRank
OPS_ERR_IF(moeExpertNumPerRank != gmm1FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Weight does not match local expert number per rank."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(h != gmm1FirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm1Weight input length does not equals to token hidden size."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen =
gmm1FirstTensorElementShape.GetDim(SINGLE_HIDDEN_INDEX);
}
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList = false;
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context,
DispatchGmmCombineDecodeTilingData *tilingData)
{
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint32_t gmm1ScaleListLen = CountTensorListLen(context, INPUT_GMM1_WEIGHT_SCALE_INDEX);
auto gmm1ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_SCALE_INDEX, 0);
auto gmm1ScaleFirstTensorElementShape = gmm1ScaleFirstTensorElement->GetOriginShape();
uint32_t elementDims = gmm1ScaleFirstTensorElementShape.GetDimNum();
OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm1WeightScale shape is invalid."),
return ge::GRAPH_FAILED);
if (gmm1ScaleListLen > 1) { // List
OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
} else { // Single
if (elementDims == 1) { // one localExpert perRank
OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
} else { // multi localExperts perRank
OPS_ERR_IF(moeExpertNumPerRank != gmm1ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm1Scale does not match local expert number perRank."), return ge::GRAPH_FAILED);
OPS_ERR_IF(n != gmm1ScaleFirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm1Scale length does not equals to gmm1 hidden size."), return ge::GRAPH_FAILED);
}
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData)
{
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint32_t gmm2ListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_INDEX);
auto gmm2FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_INDEX, 0);
auto gmm2FirstTensorElementShape = gmm2FirstTensorElement->GetOriginShape();
uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum();
OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."),
return ge::GRAPH_FAILED);
if (gmm2ListLen > 1) { // List
OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank,
OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED);
OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2 does not equals to token hidden size."), return ge::GRAPH_FAILED);
OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm2 does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
} else { // Single
if (elementDims == 2) { // one localExpert perRank
OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED);
OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
} else { // multi localExperts perRank
OPS_ERR_IF(moeExpertNumPerRank != gmm2FirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2Weight does not match local expert num perRank."), return ge::GRAPH_FAILED);
OPS_ERR_IF(n / 2 != gmm2FirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm2Weight does not equals to token hidden size."), return ge::GRAPH_FAILED);
OPS_ERR_IF(h != gmm2FirstTensorElementShape.GetDim(2),
OPS_LOG_E(nodeName, "gmm2Weight does not match half of gmm1 hidden size."), return ge::GRAPH_FAILED);
}
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context,
DispatchGmmCombineDecodeTilingData *tilingData)
{
const char *nodeName = context->GetNodeName();
uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint32_t gmm2ScaleListLen = CountTensorListLen(context, INPUT_GMM2_WEIGHT_SCALE_INDEX);
auto gmm2ScaleFirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM2_WEIGHT_SCALE_INDEX, 0);
auto gmm2ScaleFirstTensorElementShape = gmm2ScaleFirstTensorElement->GetOriginShape();
uint32_t elementDims = gmm2ScaleFirstTensorElementShape.GetDimNum();
OPS_ERR_IF(elementDims != 1 && elementDims != 2, OPS_LOG_E(nodeName, "gmm2WeightScale shape is invalid."),
return ge::GRAPH_FAILED);
if (gmm2ScaleListLen > 1) { // List
OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
} else { // Single
if (elementDims == 1) { // one localExpert perRank
OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
} else { // multi localExperts perRank
OPS_ERR_IF(moeExpertNumPerRank != gmm2ScaleFirstTensorElementShape.GetDim(0),
OPS_LOG_E(nodeName, "gmm2Scale does not match local expert number perRank."), return ge::GRAPH_FAILED);
OPS_ERR_IF(h != gmm2ScaleFirstTensorElementShape.GetDim(1),
OPS_LOG_E(nodeName, "gmm2Scale does not match token hidden size."), return ge::GRAPH_FAILED);
}
}
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckWeightTensorList(gert::TilingContext *context,
DispatchGmmCombineDecodeTilingData *tilingData)
{
if (CheckGmm1Shape(context, tilingData) == ge::GRAPH_SUCCESS &&
CheckGmm1ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS &&
CheckGmm2Shape(context, tilingData) == ge::GRAPH_SUCCESS &&
CheckGmm2ScaleShape(context, tilingData) == ge::GRAPH_SUCCESS) {
return ge::GRAPH_SUCCESS;
}
return ge::GRAPH_FAILED;
}
ge::graphStatus CheckXActiveMaskShape(gert::TilingContext *context, const char *nodeName,
DispatchGmmCombineDecodeTilingData &tilingData)
{
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
@@ -76,55 +247,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
return ge::GRAPH_FAILED);
const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
if (xActiveMaskStorageShape != nullptr) {
@@ -312,21 +435,19 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."),
return ge::GRAPH_FAILED);
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
OPS_ERR_IF(CheckWeightTensorList(context, tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckWeightTensorList failed."), return ge::GRAPH_FAILED);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = aicNum;
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckData failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckXActiveMaskShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "CheckXActiveMaskShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
@@ -338,6 +459,9 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
tilingKey |= EXEC_FLAG_DEEP_FUSE;
}
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) {
tilingKey |= EXEC_FLAG_TENSOR_LIST;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;

View File

@@ -17,7 +17,7 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
GM_ADDR output, GM_ADDR expertTokenNums,
// system
GM_ADDR workspace, GM_ADDR tiling)
{
@@ -25,10 +25,11 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
op.Process();
}
}

View File

@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
@@ -138,7 +138,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
gmEpSendCount,
xActiveMask,
gmResvered,
gmOutputRecvCount,
gmExpertTokenNums,
epRankSize,
epRankId,
moeExpertNum,
@@ -244,7 +244,7 @@ public:
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
GM_ADDR output, GM_ADDR expertTokenNums,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process();
@@ -257,7 +257,7 @@ private:
GM_ADDR gmWeight2_;
GM_ADDR gmScale2_;
GM_ADDR gmOutput_;
GM_ADDR gmOutputRecvCount_;
GM_ADDR gmExpertTokenNums_;
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
@@ -296,7 +296,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
GM_ADDR output, GM_ADDR expertTokenNums,
// system
GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{
@@ -312,7 +312,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
gmWeight2_ = gmm2_weight;
gmScale2_ = gmm2_weight_scale;
gmOutput_ = output;
gmOutputRecvCount_ = outputRecvCount;
gmExpertTokenNums_ = expertTokenNums;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
xActiveMask_ = x_active_mask;
@@ -396,7 +396,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
icache_preload(8);
@@ -416,7 +416,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>();
Arch::CrossCoreFlag gmm1AivFinished{0};

View File

@@ -10,6 +10,7 @@
#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP
#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h"
#include "../../raw_distributed/cam_moe_distribute_combine.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/cross_core_sync.hpp"
@@ -122,7 +123,10 @@ public:
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer(params.ptrA);
AscendC::GlobalTensor<ElementB> gmB;
gmB.SetGlobalBuffer(params.ptrB);
AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB));
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr<int32_t>(0)));
}
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
@@ -139,6 +143,10 @@ public:
uint32_t stageUsed = 0;
uint32_t startCoreIdx = 0;
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(
gmBlistTensorDesc.GetDataPtr<int32_t>(groupIdx)));
}
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()};
@@ -189,7 +197,9 @@ public:
}
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
}
startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
}
@@ -232,6 +242,12 @@ public:
uint32_t stageId = 0;
uint32_t startCoreIdx = 0;
AscendC::ListTensorDesc gmScaleListTensor;
gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale));
__gm__ ElementScale* gmScalePtr;
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr<int32_t>(0));
}
for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) {
uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx)
: (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1));
@@ -241,12 +257,22 @@ public:
LayoutPerTokenScale layoutPerTokenScale =
params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN());
EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale,
EpilogueParams epilogueParams;
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(
gmScaleListTensor.GetDataPtr<int32_t>(groupIdx));
epilogueParams = EpilogueParams {
gmScalePtr, layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
params.ptrD + gmGroupOffsetD, layoutD};
} else {
epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale,
layoutScale,
params.ptrPerTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
params.ptrD + gmGroupOffsetD,
layoutD};
}
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
blockEpilogue.UpdateParams(epilogueParams);
uint32_t coreLoops = blockScheduler.GetCoreLoops();
@@ -270,7 +296,9 @@ public:
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
gmGroupOffsetScale += inGroupProblemShape.n();
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetScale += inGroupProblemShape.n();
}
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n();

View File

@@ -9,6 +9,7 @@
*/
#pragma once
#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h"
#include "catlass/catlass.hpp"
#include "catlass/arch/cross_core_sync.hpp"
#include "catlass/arch/resource.hpp"
@@ -416,7 +417,7 @@ public:
GM_ADDR gmExpandIdx;
GM_ADDR gmEpSendCount;
GM_ADDR gmResvered;
GM_ADDR gmOutputRecvCount;
GM_ADDR gmExpertTokenNums;
uint32_t epRankSize;
uint32_t epRankId;
@@ -440,7 +441,7 @@ public:
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_,
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
uint32_t h)
@@ -465,7 +466,7 @@ public:
gmexpertIds(gmexpertIds_),
gmExpandIdx(gmExpandIdx_),
gmEpSendCount(gmEpSendCount_),
gmOutputRecvCount(gmOutputRecvCount_),
gmExpertTokenNums(gmExpertTokenNums_),
gmXActiveMask(gmXActiveMask_),
gmResvered(gmResvered_),
epRankSize(epRankSize_),
@@ -535,7 +536,10 @@ public:
AscendC::GlobalTensor<ElementA> gmA;
gmA.SetGlobalBuffer(params.ptrA);
AscendC::GlobalTensor<ElementB> gmB;
gmB.SetGlobalBuffer(params.ptrB);
AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB));
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr<int32_t>(0)));
}
AscendC::GlobalTensor<ElementGroupList> groupList;
groupList.SetGlobalBuffer(params.ptrGroupList);
@@ -555,6 +559,10 @@ public:
static_cast<uint8_t>(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part
uint32_t target = 1;
for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) {
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(
gmBlistTensorDesc.GetDataPtr<int32_t>(groupIdx)));
}
groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) +
groupIdx * GROUP_INFO_SIZE);
// wait AIV recv needed tokens
@@ -619,7 +627,9 @@ public:
}
gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
}
startCoreIdx = (startCoreIdx + coreLoops) % aicNum;
}
@@ -1087,7 +1097,7 @@ public:
}
CATLASS_DEVICE
void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount)
void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset)
{
// calculate token index in output tensor
int64_t subUbOffset = ubOffset;
@@ -1113,15 +1123,6 @@ public:
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM,
{1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
if (isRecvCore && recvCoreIdx == 0) {
AscendC::GlobalTensor<int32_t> recvCountTensor;
recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount);
AscendC::DataCopyExtParams dataCopyParams = {
1U, static_cast<uint32_t>(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U};
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(0);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(0);
AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast<int32_t>(), dataCopyParams);
}
AscendC::LocalTensor<float> workLocalTensor = resource.ubBuf.template GetBufferByByte<float>(subUbOffset);
AscendC::PipeBarrier<PIPE_V>();
AscendC::ReduceSum<float>(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor,
@@ -1222,7 +1223,7 @@ public:
}
CATLASS_DEVICE
void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount)
void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount)
{
ubOffset = 0;
RecvCount(ubOffset);
@@ -1249,7 +1250,7 @@ public:
if (startRankId < recvExpertNum) {
// RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info
GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount);
GetCumSum(startRankId, recvExpertNum, ubOffset);
RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset);
}
@@ -1291,8 +1292,13 @@ public:
uint32_t stageId = 0;
uint32_t target = 1;
uint32_t startCoreIdx = 0;
AscendC::ListTensorDesc gmScaleListTensor;
AscendC::GlobalTensor<int32_t> groupTokenNumStateTensor;
gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale));
__gm__ ElementScale* gmScalePtr;
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr<int32_t>(0));
}
for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) {
// just like AIC
groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) +
@@ -1311,12 +1317,22 @@ public:
LayoutPerTokenScale layoutPerTokenScale =
wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
LayoutD layoutD = layout::RowMajor{currentM, n};
EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
layoutScale,
gmTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
gmSwigluOutput + gmGroupOffsetD,
layoutD};
EpilogueParams epilogueParams;
if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) {
gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(
gmScaleListTensor.GetDataPtr<int32_t>(groupIdx));
epilogueParams = EpilogueParams {
gmScalePtr, layoutScale,
gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale,
gmSwigluOutput + gmGroupOffsetD, layoutD};
} else {
epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale,
layoutScale,
gmTokenScale + gmGroupOffsetPerTokenScale,
layoutPerTokenScale,
gmSwigluOutput + gmGroupOffsetD,
layoutD};
}
blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
blockEpilogue.UpdateParams(epilogueParams);
uint32_t coreLoops = blockScheduler.GetCoreLoops();
@@ -1340,7 +1356,9 @@ public:
stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0;
}
gmGroupOffsetScale += inGroupProblemShape.n();
if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) {
gmGroupOffsetScale += inGroupProblemShape.n();
}
gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
gmGroupOffsetD += currentM * n;
@@ -1485,7 +1503,7 @@ public:
}
CATLASS_DEVICE
void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount)
void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums)
{
if (aivIdx == aiCoreGroupNum * subBlockNum - 1) {
// clean
@@ -1504,19 +1522,32 @@ public:
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList));
AscendC::GlobalTensor<int32_t> sendCountsGlobal;
sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount));
AscendC::GlobalTensor<int64_t> nonCumSumExpertTokenNumsTensor;
nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums);
uint32_t tmpTokenNum = 0;
for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) {
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<int32_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(
sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]);
__asm__ __volatile__("");
uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1);
expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum);
uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum;
nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum);
tmpTokenNum = tokenNum;
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(
expertTokenNumsOutGMTensor_[localMoeIndex]);
__asm__ __volatile__("");
__asm__ __volatile__("");
AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
AscendC::DcciDst::CACHELINE_OUT>(
nonCumSumExpertTokenNumsTensor[localMoeIndex]);
__asm__ __volatile__("");
}
}
}
@@ -1531,8 +1562,7 @@ public:
(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
}
if (isRecvCore) {
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
(GM_ADDR)params.gmOutputRecvCount);
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount);
}
auto gmSwigluOutput = reinterpret_cast<__gm__ float *>(
@@ -1547,7 +1577,7 @@ public:
AscendC::SyncAll<false>();
AscendC::PipeBarrier<PIPE_ALL>();
UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount);
UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums);
{
// dynamic quant
AscendC::GlobalTensor<int32_t> sendCountsGlobal;

View File

@@ -30,6 +30,7 @@ struct DispatchGmmCombineDecodeInfo {
uint64_t totalUbSize;
uint64_t totalWinSize;
uint64_t gmm1HLen;
bool isTensorList;
};
struct DispatchGmmCombineDecodeTilingData {
@@ -70,6 +71,7 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
constexpr uint32_t WORKSPACE_STAGES = 4;
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1);
constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H

View File

@@ -636,10 +636,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
const at::Tensor &x,
const at::Tensor &expert_ids,
const at::Tensor &gmm1_permuted_weight,
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
const at::TensorList &gmm1_permuted_weight,
const at::TensorList &gmm1_permuted_weight_scale,
const at::TensorList &gmm2_weight,
const at::TensorList &gmm2_weight_scale,
const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &x_active_mask,
@@ -660,7 +660,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
auto opts = expert_ids.options().dtype(at::kLong);
at::Tensor expert_token_nums = at::empty({num_local_experts}, opts);
vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
group_ep_chrs.push_back('\0');
@@ -689,8 +690,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
global_bs,
// output tensors
output,
ep_recv_count);
return {output, ep_recv_count};
expert_token_nums);
return {output, expert_token_nums};
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
@@ -1287,16 +1288,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant);
ops.def(
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
" Tensor gmm1_permuted_weight_scale,"
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor[] gmm1_permuted_weight,"
" Tensor[] gmm1_permuted_weight_scale,"
" Tensor[] gmm2_weight, Tensor[] gmm2_weight_scale,"
" Tensor expert_scales, Tensor? expert_smooth_scales=None,"
" Tensor? x_active_mask=None,"
" str group_ep='',"
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
" int shared_expert_num=1, int shared_expert_rank_num=0,"
" int quant_mode=0,"
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
" int global_bs=0) -> (Tensor output, Tensor expert_token_nums)"
);
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);

View File

@@ -157,10 +157,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
const at::Tensor &x,
const at::Tensor &expert_ids,
const at::Tensor &gmm1_permuted_weight,
const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale,
const at::TensorList &gmm1_permuted_weight,
const at::TensorList &gmm1_permuted_weight_scale,
const at::TensorList &gmm2_weight,
const at::TensorList &gmm2_weight_scale,
const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &x_active_mask,
@@ -181,9 +181,10 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta));
return {output, ep_recv_count};
auto opts = expert_ids.options().dtype(at::kLong);
at::Tensor expert_token_nums = at::empty({num_local_experts}, opts.device(at::kMeta));
return {output, expert_token_nums};
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,

View File

@@ -27,7 +27,8 @@ BASE_KWARGS = {
"test_bfloat16": True,
"enable_dynamic_bs": False,
"test_graph": False,
"with_mc2_mask": False
"with_mc2_mask": False,
"dynamic_eplb": False
}
@@ -48,20 +49,6 @@ def permute_weight(w: torch.Tensor, tile_n):
n).contiguous()
def from_inclusive_prefix_sum(pref):
if isinstance(pref, torch.Tensor):
if pref.numel() == 0:
return pref
return torch.cat([pref[:1], pref[1:] - pref[:-1]])
if not pref:
return []
out = [pref[0]]
for i in range(1, len(pref)):
out.append(pref[i] - pref[i - 1])
return out
def output_to_file(rank_id):
return False
@@ -80,7 +67,8 @@ class DecodeMoeOps(torch.nn.Module):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__()
self.ep_hcomm_info = ep_hcomm_info
self.batch_size = batch_size
@@ -95,6 +83,7 @@ class DecodeMoeOps(torch.nn.Module):
shared_expert_rank_num)
self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
self.ep_recv_count_size = self.local_expert_num * ep_world_size
self.dynamic_eplb = dynamic_eplb
self.gmm1_weight = torch.empty([
self.local_expert_num, self.token_hidden_size,
self.moe_intermediate_size * 2
@@ -152,12 +141,13 @@ class SmallOps(DecodeMoeOps):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
shared_expert_rank_num, dynamic_eplb)
self.tp_hcomm_info = ""
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
@@ -232,7 +222,7 @@ class SmallOps(DecodeMoeOps):
shared_expert_num=1,
shared_expert_rank_num=self.shared_expert_rank_num,
global_bs=self.batch_size * self.ep_world_size)
return (combine_output, ep_send_counts[:self.ep_recv_count_size])
return (combine_output, expert_token_nums)
class FusionOp(DecodeMoeOps):
@@ -249,12 +239,13 @@ class FusionOp(DecodeMoeOps):
ep_world_size,
moe_expert_num,
global_rank_id,
shared_expert_rank_num=0):
shared_expert_rank_num=0,
dynamic_eplb=False):
super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
gmm2_weight_scale, ep_hcomm_info, batch_size,
token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
shared_expert_rank_num, dynamic_eplb)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
@@ -278,6 +269,34 @@ class FusionOp(DecodeMoeOps):
global_bs=self.batch_size * self.ep_world_size)
return output
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()
if self.dynamic_eplb:
self.gmm1_weight = [
weight.clone() for weight in gmm1_weight.unbind(dim=0)
]
self.gmm1_weight_scale_fp32 = [
weight.clone() for weight in gmm1_weight_scale.unbind(dim=0)
]
self.gmm2_weight = [
weight.clone() for weight in gmm2_weight.unbind(dim=0)
]
self.gmm2_weight_scale_fp32 = [
weight.clone() for weight in gmm2_weight_scale.unbind(dim=0)
]
else:
self.gmm1_weight = [gmm1_weight.clone()]
self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()]
self.gmm2_weight = [gmm2_weight.clone()]
self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()]
def generate_datas(batch_size,
token_hidden_size,
@@ -362,7 +381,8 @@ def run_once(local_rank_id,
test_bfloat16=True,
enable_dynamic_bs=False,
test_graph=False,
with_mc2_mask=False):
with_mc2_mask=False,
dynamic_eplb=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id # single node
@@ -396,10 +416,10 @@ def run_once(local_rank_id,
weight_datas = [
data.npu() if data is not None else None for data in weight_datas
]
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
*parameter).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore
small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter,
dynamic_eplb).npu() # type: ignore
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter,
dynamic_eplb).npu() # type: ignore
if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
@@ -411,7 +431,7 @@ def run_once(local_rank_id,
dist.destroy_process_group()
if log_file is not None:
log_file.close()
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
fused_op_token_output[0:valid_token_num].cpu(),
atol=2.0,
@@ -431,9 +451,19 @@ def test_dispatch_gmm_combine_decode_base():
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_with_mc2_mask():
custom_kwargs = BASE_KWARGS
custom_kwargs["with_mc2_mask"] = True
ep_world_size = custom_kwargs["ep_world_size"]
custom_args = tuple(custom_kwargs.values())
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
@torch.inference_mode()
def test_dispatch_gmm_combine_decode_dynamic_eplb():
custom_kwargs = BASE_KWARGS
custom_kwargs["dynamic_eplb"] = True
ep_world_size = custom_kwargs["ep_world_size"]
custom_args = tuple(custom_kwargs.values())
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)

View File

@@ -242,15 +242,14 @@ def select_moe_comm_method(num_tokens: int,
ascend_config = get_ascend_config()
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
not dynamic_eplb)
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
dispatch_ffn_combine_enable = get_ep_group().world_size <= 16 and (
not is_draft_model) and (not dynamic_eplb)
if num_tokens <= mc2_tokens_capacity:
fused_decode_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_decode_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_draft_model)
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_decode_enable = fused_mc2_enable and \
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
@@ -258,8 +257,7 @@ def select_moe_comm_method(num_tokens: int,
else:
fused_prefill_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_prefill_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_draft_model)
fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_prefill_enable = False
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL

View File

@@ -131,7 +131,7 @@ env_variables: Dict[str, Callable[[], Any]] = {
# `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb.
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
# `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
# with W8A8, non-dynamic-eplb. And MTP layer must be W8A8.
# with W8A8. And MTP layer must be W8A8.
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
# Whether to anbale balance scheduling

View File

@@ -54,12 +54,15 @@ class VllmEplbAdaptor(EplbAdaptor):
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w13_weight_offset",
"w2_weight_scale_list", "w2_weight_offset"
"w2_weight_scale_list", "w2_weight_offset",
"w2_weight_scale_fp32_list"
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
@@ -97,7 +100,8 @@ class VllmEplbAdaptor(EplbAdaptor):
self.num_dense_layers) + ".mlp.experts." + name
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
"w13_weight_scale_fp32_list", "w2_weight_scale_list",
"w2_weight_scale_fp32_list"
]:
expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone()
@@ -118,7 +122,7 @@ class VllmEplbAdaptor(EplbAdaptor):
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list",
"w2_weight_scale_list"
"w2_weight_scale_list", "w2_weight_scale_fp32_list"
]:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) +

View File

@@ -300,6 +300,8 @@ class FusedMC2CommImpl(MoECommMethod):
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
group_list_type = None
expert_tokens = None
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
@@ -316,13 +318,14 @@ class FusedMC2CommImpl(MoECommMethod):
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
assert expert_map is not None, "expert_map cannot be None."
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
group_list_type = 1
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
x=hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1[0],
gmm1_permuted_weight_scale=w1_scale[0],
gmm2_weight=w2[0],
gmm2_weight_scale=w2_scale[0],
gmm1_permuted_weight=w1,
gmm1_permuted_weight_scale=w1_scale,
gmm2_weight=w2,
gmm2_weight_scale=w2_scale,
expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
@@ -333,4 +336,6 @@ class FusedMC2CommImpl(MoECommMethod):
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return FusedExpertsResult(routed_out=out)
return FusedExpertsResult(routed_out=out,
group_list_type=group_list_type,
expert_tokens=expert_tokens)

View File

@@ -254,7 +254,8 @@ class AscendW8A8DynamicFusedMoEMethod:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.w2_weight_scale_list
w2_scale = layer.w2_weight_scale_fp32_list \
if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
@@ -333,11 +334,16 @@ class AscendW8A8DynamicFusedMoEMethod:
weight.clone()
for weight in layer.w2_weight_scale.data.unbind(dim=0)
]
layer.w2_weight_scale_fp32_list = [
weight.clone()
for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0)
]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
del layer.w2_weight_scale_fp32
torch.npu.empty_cache()