[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)
### What this PR does / why we need it?
This PR adds an optional input, x_active_mask, to the DispatchGmmCombineDecode operator.
When x_active_mask is provided,
only tokens whose mask value is True are dispatched and processed.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
|
|||||||
const aclTensor *gmm1PermutedWeightScale,
|
const aclTensor *gmm1PermutedWeightScale,
|
||||||
const aclTensor *gmm2Weight,
|
const aclTensor *gmm2Weight,
|
||||||
const aclTensor *gmm2WeightScale,
|
const aclTensor *gmm2WeightScale,
|
||||||
|
const aclTensor *expertScales,
|
||||||
const aclTensor *expertSmoothScalesOptional,
|
const aclTensor *expertSmoothScalesOptional,
|
||||||
const aclTensor *expertScalesOptional,
|
const aclTensor *xActiveMaskOptional,
|
||||||
char *groupEp,
|
char *groupEp,
|
||||||
int64_t epRankSize,
|
int64_t epRankSize,
|
||||||
int64_t epRankId,
|
int64_t epRankId,
|
||||||
@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
|
|||||||
const aclTensor *gmm1PermutedWeightScale,
|
const aclTensor *gmm1PermutedWeightScale,
|
||||||
const aclTensor *gmm2Weight,
|
const aclTensor *gmm2Weight,
|
||||||
const aclTensor *gmm2WeightScale,
|
const aclTensor *gmm2WeightScale,
|
||||||
|
const aclTensor *expertScales,
|
||||||
const aclTensor *expertSmoothScalesOptional,
|
const aclTensor *expertSmoothScalesOptional,
|
||||||
const aclTensor *expertScalesOptional,
|
const aclTensor *xActiveMaskOptional,
|
||||||
char *groupEp,
|
char *groupEp,
|
||||||
int64_t epRankSize,
|
int64_t epRankSize,
|
||||||
int64_t epRankId,
|
int64_t epRankId,
|
||||||
@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
|
|||||||
aclOpExecutor **executor)
|
aclOpExecutor **executor)
|
||||||
{
|
{
|
||||||
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
|
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
|
||||||
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
|
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
|
||||||
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
|
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
|
||||||
output, epRecvCount, workspaceSize, executor);
|
output, epRecvCount, workspaceSize, executor);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
|
|||||||
const aclTensor *gmm1PermutedWeightScale,
|
const aclTensor *gmm1PermutedWeightScale,
|
||||||
const aclTensor *gmm2Weight,
|
const aclTensor *gmm2Weight,
|
||||||
const aclTensor *gmm2WeightScale,
|
const aclTensor *gmm2WeightScale,
|
||||||
|
const aclTensor *expertScales,
|
||||||
const aclTensor *expertSmoothScalesOptional,
|
const aclTensor *expertSmoothScalesOptional,
|
||||||
const aclTensor *expertScalesOptional,
|
const aclTensor *xActiveMaskOptional,
|
||||||
char *groupEp,
|
char *groupEp,
|
||||||
int64_t epRankSize,
|
int64_t epRankSize,
|
||||||
int64_t epRankId,
|
int64_t epRankId,
|
||||||
|
|||||||
@@ -45,14 +45,19 @@ public:
|
|||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
|
this->Input("expert_scales")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||||
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("expert_smooth_scales")
|
this->Input("expert_smooth_scales")
|
||||||
.ParamType(OPTIONAL)
|
.ParamType(OPTIONAL)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("expert_scales")
|
this->Input("x_active_mask")
|
||||||
.ParamType(OPTIONAL)
|
.ParamType(OPTIONAL)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_BOOL, ge::DT_BOOL})
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Output("output")
|
this->Output("output")
|
||||||
|
|||||||
@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
|
|||||||
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
|
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
|
||||||
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
|
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
|
||||||
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
|
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
|
||||||
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
|
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
|
||||||
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
|
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
|
||||||
|
constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;
|
||||||
|
|
||||||
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
|
||||||
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
|
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
|
||||||
@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
|
|||||||
constexpr uint32_t MAX_BATCH_SIZE = 256;
|
constexpr uint32_t MAX_BATCH_SIZE = 256;
|
||||||
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
|
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
|
||||||
constexpr uint32_t SUPPORT_TOP_K = 12;
|
constexpr uint32_t SUPPORT_TOP_K = 12;
|
||||||
|
constexpr uint32_t ONE_DIMS = 1;
|
||||||
constexpr uint32_t TWO_DIMS = 2;
|
constexpr uint32_t TWO_DIMS = 2;
|
||||||
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
|
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
|
||||||
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
|
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
|
||||||
@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
|
|||||||
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
|
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
|
||||||
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
|
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
|
||||||
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
|
||||||
|
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
||||||
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||||
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
|
||||||
|
|
||||||
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
|
|||||||
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
|
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
|
||||||
return ge::GRAPH_FAILED);
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
||||||
|
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
||||||
|
if (xActiveMaskStorageShape != nullptr) {
|
||||||
|
OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
|
||||||
|
OPS_LOG_E(nodeName, " xActiveMask scale shape dims must be 1, but current dim num is %lu.",
|
||||||
|
xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
|
||||||
|
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
|
||||||
|
"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
|
|||||||
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
|
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
|
||||||
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
|
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
|
||||||
return ge::GRAPH_FAILED);
|
return ge::GRAPH_FAILED);
|
||||||
|
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
|
||||||
|
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
|
||||||
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
|
||||||
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
|
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
|
||||||
SetHcommCfg(context, tilingData, groupEp);
|
SetHcommCfg(context, tilingData, groupEp);
|
||||||
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
|
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
|
||||||
context->SetTilingKey(0);
|
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
|
||||||
} else {
|
bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
|
||||||
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
|
uint64_t tilingKey = 0;
|
||||||
|
if (xActiveMaskEnable) {
|
||||||
|
tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
|
||||||
}
|
}
|
||||||
|
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
|
||||||
|
tilingKey |= EXEC_FLAG_DEEP_FUSE;
|
||||||
|
}
|
||||||
|
context->SetTilingKey(tilingKey);
|
||||||
context->SetBlockDim(aicNum);
|
context->SetBlockDim(aicNum);
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,8 @@
|
|||||||
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
||||||
// input
|
// input
|
||||||
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
||||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
|
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
|
||||||
|
GM_ADDR x_active_mask,
|
||||||
// output
|
// output
|
||||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||||
// system
|
// system
|
||||||
@@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
|||||||
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
|
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
|
||||||
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
|
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
|
||||||
GET_TILING_DATA(tiling_data, tiling);
|
GET_TILING_DATA(tiling_data, tiling);
|
||||||
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
|
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
|
||||||
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
|
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
|
||||||
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
|
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
|
||||||
expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
|
expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
|
||||||
op.Process();
|
op.Process();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
|
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
|
||||||
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
|
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
|
||||||
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
|
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
|
||||||
GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
|
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
|
||||||
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
|
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
|
||||||
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
|
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
|
||||||
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
|
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
|
||||||
@@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
using GemmKernel = typename std::conditional<
|
using GemmKernel = typename std::conditional<
|
||||||
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
|
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
|
||||||
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
|
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
|
||||||
XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
|
EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
|
||||||
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
||||||
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
|
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
|
||||||
|
|
||||||
@@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
gmexpertIds,
|
gmexpertIds,
|
||||||
gmExpandIdx,
|
gmExpandIdx,
|
||||||
gmEpSendCount,
|
gmEpSendCount,
|
||||||
|
xActiveMask,
|
||||||
gmResvered,
|
gmResvered,
|
||||||
gmOutputRecvCount,
|
gmOutputRecvCount,
|
||||||
epRankSize,
|
epRankSize,
|
||||||
@@ -241,7 +242,7 @@ public:
|
|||||||
__aicore__ inline void Init(
|
__aicore__ inline void Init(
|
||||||
// input
|
// input
|
||||||
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
||||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
|
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
|
||||||
// output
|
// output
|
||||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||||
// system
|
// system
|
||||||
@@ -260,6 +261,7 @@ private:
|
|||||||
GM_ADDR workspaceGM_;
|
GM_ADDR workspaceGM_;
|
||||||
GM_ADDR gmSmoothScales_;
|
GM_ADDR gmSmoothScales_;
|
||||||
GM_ADDR gmexpertScales_;
|
GM_ADDR gmexpertScales_;
|
||||||
|
GM_ADDR xActiveMask_;
|
||||||
|
|
||||||
uint32_t maxTokenNum_{0};
|
uint32_t maxTokenNum_{0};
|
||||||
uint32_t gmm1OutputDim_{0};
|
uint32_t gmm1OutputDim_{0};
|
||||||
@@ -291,7 +293,8 @@ template <TemplateMC2TypeClass>
|
|||||||
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
||||||
// input
|
// input
|
||||||
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
|
||||||
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
|
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
|
||||||
|
GM_ADDR x_active_mask,
|
||||||
// output
|
// output
|
||||||
GM_ADDR output, GM_ADDR outputRecvCount,
|
GM_ADDR output, GM_ADDR outputRecvCount,
|
||||||
// system
|
// system
|
||||||
@@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
|
|||||||
gmOutputRecvCount_ = outputRecvCount;
|
gmOutputRecvCount_ = outputRecvCount;
|
||||||
workspaceGM_ = workspaceGM;
|
workspaceGM_ = workspaceGM;
|
||||||
gmexpertScales_ = expert_scales;
|
gmexpertScales_ = expert_scales;
|
||||||
|
xActiveMask_ = x_active_mask;
|
||||||
tilingData_ = tilingData;
|
tilingData_ = tilingData;
|
||||||
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
|
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
|
||||||
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
|
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
|
||||||
@@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
|||||||
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
|
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
|
||||||
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
|
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
|
||||||
|
|
||||||
if constexpr (EXEC_FLAG == 0) {
|
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
|
||||||
if constexpr (g_coreType == AscendC::AIV) {
|
if constexpr (g_coreType == AscendC::AIV) {
|
||||||
AscendC::TPipe tpipe;
|
AscendC::TPipe tpipe;
|
||||||
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
|
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
|
||||||
dispatcher;
|
dispatcher;
|
||||||
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
|
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
|
||||||
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
|
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
|
||||||
dispatcher.Process();
|
dispatcher.Process();
|
||||||
tpipe.Destroy();
|
tpipe.Destroy();
|
||||||
@@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
|||||||
Gmm1BlockScheduler>(
|
Gmm1BlockScheduler>(
|
||||||
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
|
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
|
||||||
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
|
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
|
||||||
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
|
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
|
||||||
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
|
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
|
||||||
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
|
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
@@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
|||||||
|
|
||||||
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
|
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
|
||||||
if (g_coreType == AscendC::AIV) {
|
if (g_coreType == AscendC::AIV) {
|
||||||
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
|
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
|
||||||
workspaceGM_, nullptr, tilingData_);
|
workspaceGM_, nullptr, tilingData_);
|
||||||
}
|
}
|
||||||
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
|
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
|
||||||
|
|||||||
@@ -353,7 +353,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
|
|||||||
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
|
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
||||||
class ElementGroupList_>
|
class ElementGroupList_>
|
||||||
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
|
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
|
||||||
{
|
{
|
||||||
@@ -411,6 +411,7 @@ public:
|
|||||||
GM_ADDR gmX;
|
GM_ADDR gmX;
|
||||||
GM_ADDR debugGm;
|
GM_ADDR debugGm;
|
||||||
GM_ADDR gmexpertIds;
|
GM_ADDR gmexpertIds;
|
||||||
|
GM_ADDR gmXActiveMask;
|
||||||
|
|
||||||
GM_ADDR gmExpandIdx;
|
GM_ADDR gmExpandIdx;
|
||||||
GM_ADDR gmEpSendCount;
|
GM_ADDR gmEpSendCount;
|
||||||
@@ -438,7 +439,7 @@ public:
|
|||||||
LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
|
LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
|
||||||
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
|
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
|
||||||
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
|
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
|
||||||
GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_,
|
GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
|
||||||
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
|
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
|
||||||
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
|
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
|
||||||
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
|
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
|
||||||
@@ -465,6 +466,7 @@ public:
|
|||||||
gmExpandIdx(gmExpandIdx_),
|
gmExpandIdx(gmExpandIdx_),
|
||||||
gmEpSendCount(gmEpSendCount_),
|
gmEpSendCount(gmEpSendCount_),
|
||||||
gmOutputRecvCount(gmOutputRecvCount_),
|
gmOutputRecvCount(gmOutputRecvCount_),
|
||||||
|
gmXActiveMask(gmXActiveMask_),
|
||||||
gmResvered(gmResvered_),
|
gmResvered(gmResvered_),
|
||||||
epRankSize(epRankSize_),
|
epRankSize(epRankSize_),
|
||||||
epRankId(epRankId_),
|
epRankId(epRankId_),
|
||||||
@@ -635,6 +637,39 @@ public:
|
|||||||
AscendC::SyncAll<false>();
|
AscendC::SyncAll<false>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Counts the entries set in the x_active_mask input and stores the result in activeMaskBsCnt.
//
// Steps, in order: copy axisBS bool entries from global memory into the UB buffer, cast them
// (viewed as int8_t) to half, sum the half values, and read element 0 of the sum output as the count.
//
// gmXActiveMask: global-memory address of the mask; it is read as axisBS elements of type bool.
// ubOffset:      byte offset into resource.ubBuf where this function starts carving its scratch
//                tensors. It is taken by value, so the offsets consumed here are not reported
//                back to the caller.
CATLASS_DEVICE
|
||||||
|
void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset)
|
||||||
|
{
|
||||||
|
int64_t subUbOffset = ubOffset;
|
||||||
|
// Scratch tensor 1: raw mask bytes as loaded from global memory.
AscendC::LocalTensor<bool> maskInputTensor = (resource.ubBuf.template
|
||||||
|
GetBufferByByte<bool>(subUbOffset));
|
||||||
|
// Same buffer viewed as int8_t so it can be used as the source of the Cast below.
AscendC::LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.template ReinterpretCast<int8_t>();
|
||||||
|
subUbOffset += CEIL_UP(axisBS * sizeof(bool));
|
||||||
|
// Scratch tensor 2: mask values converted to half, used as the Sum input.
AscendC::LocalTensor<half> maskTmpTensor = (resource.ubBuf.template
|
||||||
|
GetBufferByByte<half>(subUbOffset));
|
||||||
|
subUbOffset += CEIL_UP(axisBS * sizeof(half));
|
||||||
|
// Scratch tensor 3: output of the Sum; only element 0 is read afterwards.
AscendC::LocalTensor<half> sumOutTensor = (resource.ubBuf.template
|
||||||
|
GetBufferByByte<half>(subUbOffset));
|
||||||
|
// NOTE(review): subUbOffset is not read again after this update within this function.
subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<bool> xActiveMaskGMTensor;
|
||||||
|
xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask);
|
||||||
|
// Passed below as the second field of SumParams.
// NOTE(review): this is a CEIL_UP of a byte count for bool, while the Sum input tensor holds half — confirm the intended unit.
uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool));
|
||||||
|
|
||||||
|
// Single block of axisBS * sizeof(bool) bytes; the remaining copy parameters are zero.
AscendC::DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS * sizeof(bool)), 0U, 0U, 0U};
|
||||||
|
AscendC::DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
|
||||||
|
AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams);
|
||||||
|
// MTE2_V event set/wait between the copy and the Cast — presumably orders the load before the vector op; confirm.
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
|
||||||
|
AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS);
|
||||||
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
// NOTE(review): the count is accumulated in half precision — confirm it stays exact for the largest supported axisBS.
AscendC::SumParams params{1, axisBsAlignSize, axisBS};
|
||||||
|
AscendC::Sum(sumOutTensor, maskTmpTensor, params);
|
||||||
|
// V_S event set/wait before the scalar GetValue read of the Sum result.
AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
|
||||||
|
AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
|
||||||
|
// Store the summed mask as the active-token count (member activeMaskBsCnt).
activeMaskBsCnt = static_cast<int32_t>(sumOutTensor.GetValue(0));
|
||||||
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset)
|
void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset)
|
||||||
{
|
{
|
||||||
@@ -778,8 +813,8 @@ public:
|
|||||||
void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale)
|
void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale)
|
||||||
{
|
{
|
||||||
uint32_t newAivId = sendCoreIdx - sendToMoeAivNum;
|
uint32_t newAivId = sendCoreIdx - sendToMoeAivNum;
|
||||||
uint32_t sendTokenNum = axisBS / sendToShareAivNum;
|
uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum;
|
||||||
uint32_t remainderTokenNum = axisBS % sendToShareAivNum;
|
uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum;
|
||||||
uint32_t startTokenId = sendTokenNum * newAivId;
|
uint32_t startTokenId = sendTokenNum * newAivId;
|
||||||
if (newAivId < remainderTokenNum) {
|
if (newAivId < remainderTokenNum) {
|
||||||
sendTokenNum += 1;
|
sendTokenNum += 1;
|
||||||
@@ -788,7 +823,7 @@ public:
|
|||||||
startTokenId += remainderTokenNum;
|
startTokenId += remainderTokenNum;
|
||||||
}
|
}
|
||||||
uint32_t endTokenId = startTokenId + sendTokenNum;
|
uint32_t endTokenId = startTokenId + sendTokenNum;
|
||||||
if (startTokenId >= axisBS) {
|
if (startTokenId >= activeMaskBsCnt) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -962,10 +997,13 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
CATLASS_DEVICE void
|
CATLASS_DEVICE void
|
||||||
SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx)
|
SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask)
|
||||||
{
|
{
|
||||||
ubOffset = 0;
|
ubOffset = 0;
|
||||||
expertIdsCnt = axisBS * axisK;
|
if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
|
||||||
|
TokenActiveMaskCal(gmXActiveMask, ubOffset);
|
||||||
|
}
|
||||||
|
expertIdsCnt = activeMaskBsCnt * axisK;
|
||||||
|
|
||||||
AscendC::GlobalTensor<int32_t> expertIdsGMTensor_;
|
AscendC::GlobalTensor<int32_t> expertIdsGMTensor_;
|
||||||
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds);
|
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds);
|
||||||
@@ -1372,6 +1410,7 @@ public:
|
|||||||
hCommuSize = hOutSize + scaleParamPad;
|
hCommuSize = hOutSize + scaleParamPad;
|
||||||
axisHCommu = hCommuSize / sizeof(int8_t);
|
axisHCommu = hCommuSize / sizeof(int8_t);
|
||||||
axisBS = params.bs;
|
axisBS = params.bs;
|
||||||
|
activeMaskBsCnt = axisBS;
|
||||||
axisK = params.topK;
|
axisK = params.topK;
|
||||||
uint32_t maxAxisBs = params.globalBs / epRankSize;
|
uint32_t maxAxisBs = params.globalBs / epRankSize;
|
||||||
|
|
||||||
@@ -1489,7 +1528,7 @@ public:
|
|||||||
AivInitState();
|
AivInitState();
|
||||||
if (isSendCore) {
|
if (isSendCore) {
|
||||||
SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
|
SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
|
||||||
(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx);
|
(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
|
||||||
}
|
}
|
||||||
if (isRecvCore) {
|
if (isRecvCore) {
|
||||||
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
|
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
|
||||||
@@ -1596,6 +1635,7 @@ private:
|
|||||||
uint32_t hCommuSize{0};
|
uint32_t hCommuSize{0};
|
||||||
uint32_t axisHCommu{0};
|
uint32_t axisHCommu{0};
|
||||||
uint32_t axisBS{0};
|
uint32_t axisBS{0};
|
||||||
|
uint32_t activeMaskBsCnt{0};
|
||||||
uint32_t axisK{0};
|
uint32_t axisK{0};
|
||||||
uint32_t totalTokenCount{0};
|
uint32_t totalTokenCount{0};
|
||||||
uint32_t expertIdsCnt{0};
|
uint32_t expertIdsCnt{0};
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class CamMoeDistributeCombine
|
|||||||
public:
|
public:
|
||||||
__aicore__ inline CamMoeDistributeCombine(){};
|
__aicore__ inline CamMoeDistributeCombine(){};
|
||||||
__aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
|
__aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
|
||||||
GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
|
GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
|
||||||
const DispatchGmmCombineDecodeTilingData *tilingData);
|
const DispatchGmmCombineDecodeTilingData *tilingData);
|
||||||
__aicore__ inline void Process();
|
__aicore__ inline void Process();
|
||||||
__aicore__ inline void AllToAllSend();
|
__aicore__ inline void AllToAllSend();
|
||||||
@@ -145,6 +145,7 @@ private:
|
|||||||
GlobalTensor<ExpandIdxType> epSendCountGM_;
|
GlobalTensor<ExpandIdxType> epSendCountGM_;
|
||||||
GlobalTensor<ExpandIdxType> tpSendCountGM_;
|
GlobalTensor<ExpandIdxType> tpSendCountGM_;
|
||||||
GlobalTensor<float> expandScalesGM_;
|
GlobalTensor<float> expandScalesGM_;
|
||||||
|
GlobalTensor<bool> xActiveMaskGM_;
|
||||||
GlobalTensor<ExpandXType> expandOutGlobal_;
|
GlobalTensor<ExpandXType> expandOutGlobal_;
|
||||||
GlobalTensor<ExpandXType> rankWindow_;
|
GlobalTensor<ExpandXType> rankWindow_;
|
||||||
GlobalTensor<int32_t> rankStates_;
|
GlobalTensor<int32_t> rankStates_;
|
||||||
@@ -169,6 +170,8 @@ private:
|
|||||||
CombineCalcInfo calcInfo_;
|
CombineCalcInfo calcInfo_;
|
||||||
uint32_t axisBS_{0};
|
uint32_t axisBS_{0};
|
||||||
uint32_t axisMaxBs_{0};
|
uint32_t axisMaxBs_{0};
|
||||||
|
uint32_t axisBsAlignSize_{0};
|
||||||
|
uint64_t activeMaskBsCnt_{0};
|
||||||
uint32_t axisH_{0};
|
uint32_t axisH_{0};
|
||||||
uint32_t axisK_{0};
|
uint32_t axisK_{0};
|
||||||
uint32_t aivNum_{0};
|
uint32_t aivNum_{0};
|
||||||
@@ -224,6 +227,9 @@ private:
|
|||||||
TBuf<> gatherMaskOutBuf_; // gather mask output buf
|
TBuf<> gatherMaskOutBuf_; // gather mask output buf
|
||||||
TBuf<> gatherTmpBuf_;
|
TBuf<> gatherTmpBuf_;
|
||||||
TBuf<> statusSumOutBuf_;
|
TBuf<> statusSumOutBuf_;
|
||||||
|
TBuf<> xActMaskTBuf_;
|
||||||
|
TBuf<> xActMaskCastTBuf_;
|
||||||
|
TBuf<> xActMaskSumTBuf_;
|
||||||
float sumTarget_{0.0};
|
float sumTarget_{0.0};
|
||||||
int32_t epStateValue_;
|
int32_t epStateValue_;
|
||||||
bool isShardExpert_{false};
|
bool isShardExpert_{false};
|
||||||
@@ -231,7 +237,7 @@ private:
|
|||||||
|
|
||||||
template <TemplateMC2TypeClass>
|
template <TemplateMC2TypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
|
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
|
||||||
GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales,
|
GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask,
|
||||||
GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
||||||
{
|
{
|
||||||
tpipe_ = pipe;
|
tpipe_ = pipe;
|
||||||
@@ -264,8 +270,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
|
|||||||
expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
|
expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
|
||||||
epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
|
epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
|
||||||
expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
|
expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
|
||||||
|
xActiveMaskGM_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
|
||||||
expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
|
expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
|
||||||
axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
|
||||||
|
activeMaskBsCnt_ = axisBS_;
|
||||||
axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
|
||||||
axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
|
axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
|
||||||
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
||||||
@@ -395,6 +403,27 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuf
|
|||||||
tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
|
tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
|
||||||
tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
|
tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
|
||||||
tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
|
tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
|
||||||
|
|
||||||
|
if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
|
||||||
|
axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
|
||||||
|
tpipe_->InitBuffer(xActMaskTBuf_, axisBsAlignSize_);
|
||||||
|
tpipe_->InitBuffer(xActMaskCastTBuf_, axisBsAlignSize_ * sizeof(half));
|
||||||
|
tpipe_->InitBuffer(xActMaskSumTBuf_, axisBsAlignSize_ * sizeof(half));
|
||||||
|
LocalTensor<bool> xActiveMaskTensor = xActMaskTBuf_.Get<bool>();
|
||||||
|
LocalTensor<half> tempTensor = xActMaskCastTBuf_.Get<half>();
|
||||||
|
LocalTensor<half> sumOutTensor = xActMaskSumTBuf_.Get<half>();
|
||||||
|
DataCopyExtParams xActiveMaskParams{1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
|
||||||
|
DataCopyPadExtParams<bool> xActiveMaskCopyPadParams{false, 0U, 0U, 0U};
|
||||||
|
DataCopyPad(xActiveMaskTensor, xActiveMaskGM_, xActiveMaskParams, xActiveMaskCopyPadParams);
|
||||||
|
SyncFunc<AscendC::HardEvent::MTE2_V>();
|
||||||
|
LocalTensor<int8_t> xActiveMaskInt8Tensor = xActiveMaskTensor.ReinterpretCast<int8_t>();
|
||||||
|
Cast(tempTensor, xActiveMaskInt8Tensor, RoundMode::CAST_NONE, axisBS_);
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
|
SumParams params{1, axisBsAlignSize_, axisBS_};
|
||||||
|
Sum(sumOutTensor, tempTensor, params);
|
||||||
|
SyncFunc<AscendC::HardEvent::V_S>();
|
||||||
|
activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <TemplateMC2TypeClass>
|
template <TemplateMC2TypeClass>
|
||||||
@@ -645,13 +674,16 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatc
|
|||||||
template <TemplateMC2TypeClass>
|
template <TemplateMC2TypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
|
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
|
||||||
{
|
{
|
||||||
|
if (activeMaskBsCnt_ == 0U) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
uint32_t beginIndex = 0;
|
uint32_t beginIndex = 0;
|
||||||
uint32_t endIndex = 0;
|
uint32_t endIndex = 0;
|
||||||
uint32_t processLen = 0;
|
uint32_t processLen = 0;
|
||||||
uint32_t tokenOffset = 0;
|
uint32_t tokenOffset = 0;
|
||||||
if (axisBS_ < aivNum_) {
|
if (activeMaskBsCnt_ < aivNum_) {
|
||||||
uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_
|
uint32_t aivNumPerToken = aivNum_ / activeMaskBsCnt_; // activeMaskBsCnt_ < aivNum_
|
||||||
if (coreIdx_ >= (axisBS_ * aivNumPerToken)) {
|
if (coreIdx_ >= (activeMaskBsCnt_ * aivNumPerToken)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
|
uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
|
||||||
@@ -663,8 +695,8 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
|
|||||||
beginIndex = tokenIndex;
|
beginIndex = tokenIndex;
|
||||||
endIndex = beginIndex + 1U;
|
endIndex = beginIndex + 1U;
|
||||||
} else {
|
} else {
|
||||||
uint32_t tokenPerAivNum = axisBS_ / aivNum_;
|
uint32_t tokenPerAivNum = activeMaskBsCnt_ / aivNum_;
|
||||||
uint32_t remainderToken = axisBS_ % aivNum_;
|
uint32_t remainderToken = activeMaskBsCnt_ % aivNum_;
|
||||||
beginIndex = tokenPerAivNum * coreIdx_;
|
beginIndex = tokenPerAivNum * coreIdx_;
|
||||||
if (coreIdx_ < remainderToken) {
|
if (coreIdx_ < remainderToken) {
|
||||||
tokenPerAivNum++;
|
tokenPerAivNum++;
|
||||||
@@ -723,10 +755,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
|
|||||||
}
|
}
|
||||||
LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
|
LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
|
||||||
if (sharedExpertRankNum_ > 0U) {
|
if (sharedExpertRankNum_ > 0U) {
|
||||||
uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
|
uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
|
||||||
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
|
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
|
||||||
uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
|
uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
|
||||||
epRankId_ * axisBS_ / sharedExpertRankNum_;
|
epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
|
||||||
__gm__ ExpandXType *shareAddr =
|
__gm__ ExpandXType *shareAddr =
|
||||||
(__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
|
(__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
|
||||||
(tokenIndex - preCnt) * axisH_ + tokenOffset;
|
(tokenIndex - preCnt) * axisH_ + tokenOffset;
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ constexpr uint8_t BUFFER_NUM = 2;
|
|||||||
constexpr uint32_t STATE_OFFSET = 512; // state space offset
|
constexpr uint32_t STATE_OFFSET = 512; // state space offset
|
||||||
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
|
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
|
||||||
constexpr uint32_t UB_ALIGN = 32;
|
constexpr uint32_t UB_ALIGN = 32;
|
||||||
|
constexpr uint64_t ALIGNED_LEN_256 = 256UL;
|
||||||
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
|
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
|
||||||
constexpr uint8_t COMM_NUM = 2;
|
constexpr uint8_t COMM_NUM = 2;
|
||||||
constexpr uint8_t COMM_EP_IDX = 0;
|
constexpr uint8_t COMM_EP_IDX = 0;
|
||||||
@@ -47,18 +48,13 @@ __aicore__ inline void SyncFunc()
|
|||||||
AscendC::WaitFlag<event>(eventID);
|
AscendC::WaitFlag<event>(eventID);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define TemplateDispatchTypeClass \
|
|
||||||
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
|
|
||||||
bool IsNeedAllgater
|
|
||||||
#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater
|
|
||||||
|
|
||||||
using namespace AscendC;
|
using namespace AscendC;
|
||||||
template <TemplateDispatchTypeClass>
|
template <TemplateDispatchTypeClass>
|
||||||
class CamMoeDistributeDispatch
|
class CamMoeDistributeDispatch
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
__aicore__ inline CamMoeDistributeDispatch(){};
|
__aicore__ inline CamMoeDistributeDispatch(){};
|
||||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut,
|
__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut,
|
||||||
GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
|
GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
|
||||||
GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
|
GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
|
||||||
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
|
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
|
||||||
@@ -80,6 +76,7 @@ private:
|
|||||||
__aicore__ inline void AllGatherSetStatusAndWait();
|
__aicore__ inline void AllGatherSetStatusAndWait();
|
||||||
__aicore__ inline void ResetStatus();
|
__aicore__ inline void ResetStatus();
|
||||||
__aicore__ inline void QuantInit(GM_ADDR scales);
|
__aicore__ inline void QuantInit(GM_ADDR scales);
|
||||||
|
__aicore__ inline void TokenActiveMaskCal();
|
||||||
__aicore__ inline void AllgatherProcessOut();
|
__aicore__ inline void AllgatherProcessOut();
|
||||||
__aicore__ inline void UpdataMultiMoeTokenNumsOut();
|
__aicore__ inline void UpdataMultiMoeTokenNumsOut();
|
||||||
__aicore__ inline void UpdataTokenNumsOut();
|
__aicore__ inline void UpdataTokenNumsOut();
|
||||||
@@ -112,6 +109,7 @@ private:
|
|||||||
GlobalTensor<XType> xGMTensor_;
|
GlobalTensor<XType> xGMTensor_;
|
||||||
GlobalTensor<int32_t> expertIdsGMTensor_;
|
GlobalTensor<int32_t> expertIdsGMTensor_;
|
||||||
GlobalTensor<float> scalesGMTensor_;
|
GlobalTensor<float> scalesGMTensor_;
|
||||||
|
GlobalTensor<bool> xActiveMaskGMTensor_;
|
||||||
GlobalTensor<ExpandXOutType> expandXOutGMTensor_;
|
GlobalTensor<ExpandXOutType> expandXOutGMTensor_;
|
||||||
GlobalTensor<float> dynamicScalesOutGMTensor_;
|
GlobalTensor<float> dynamicScalesOutGMTensor_;
|
||||||
GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;
|
GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;
|
||||||
@@ -144,6 +142,9 @@ private:
|
|||||||
TBuf<> rowMaxBuf_;
|
TBuf<> rowMaxBuf_;
|
||||||
TBuf<> receiveDataCastFloatBuf_;
|
TBuf<> receiveDataCastFloatBuf_;
|
||||||
TBuf<> smoothScalesBuf_;
|
TBuf<> smoothScalesBuf_;
|
||||||
|
TBuf<> dstExpBuf_;
|
||||||
|
TBuf<> subExpBuf_;
|
||||||
|
TBuf<> gatherMaskTBuf_;
|
||||||
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_;
|
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_;
|
||||||
TQue<QuePosition::VECIN, 1> xInQueue_;
|
TQue<QuePosition::VECIN, 1> xInQueue_;
|
||||||
TQue<QuePosition::VECOUT, 1> xOutQueue_;
|
TQue<QuePosition::VECOUT, 1> xOutQueue_;
|
||||||
@@ -161,7 +162,10 @@ private:
|
|||||||
GM_ADDR tpLocalStatusWindowGM_;
|
GM_ADDR tpLocalStatusWindowGM_;
|
||||||
GlobalTensor<GM_ADDR> peerMemsAddrGm_;
|
GlobalTensor<GM_ADDR> peerMemsAddrGm_;
|
||||||
uint32_t axisBS_{0};
|
uint32_t axisBS_{0};
|
||||||
|
uint32_t axisBsAlignSize_{0};
|
||||||
|
uint64_t activeMaskBsCnt_{0};
|
||||||
uint32_t axisMaxBS_{0};
|
uint32_t axisMaxBS_{0};
|
||||||
|
uint64_t sendToMoeExpTokenCnt_{0};
|
||||||
uint32_t axisH_{0};
|
uint32_t axisH_{0};
|
||||||
uint32_t axisK_{0};
|
uint32_t axisK_{0};
|
||||||
uint32_t aivNum_{0};
|
uint32_t aivNum_{0};
|
||||||
@@ -183,6 +187,7 @@ private:
|
|||||||
uint32_t hSize_{0};
|
uint32_t hSize_{0};
|
||||||
uint32_t hOutSize_{0};
|
uint32_t hOutSize_{0};
|
||||||
uint32_t hCommuSize_{0};
|
uint32_t hCommuSize_{0};
|
||||||
|
uint32_t maxSize_{0};
|
||||||
uint32_t scaleParamPad_{0};
|
uint32_t scaleParamPad_{0};
|
||||||
uint32_t axisHCommu_{0};
|
uint32_t axisHCommu_{0};
|
||||||
uint32_t startExpertId_;
|
uint32_t startExpertId_;
|
||||||
@@ -216,7 +221,7 @@ private:
|
|||||||
|
|
||||||
template <TemplateDispatchTypeClass>
|
template <TemplateDispatchTypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
|
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
|
||||||
GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
|
GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
|
||||||
GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
|
GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
|
||||||
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
|
||||||
{
|
{
|
||||||
@@ -268,6 +273,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
|
|||||||
tpWorldSize_ = 1;
|
tpWorldSize_ = 1;
|
||||||
xGMTensor_.SetGlobalBuffer((__gm__ XType *)x);
|
xGMTensor_.SetGlobalBuffer((__gm__ XType *)x);
|
||||||
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds);
|
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds);
|
||||||
|
xActiveMaskGMTensor_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
|
||||||
expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut);
|
expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut);
|
||||||
dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
|
dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
|
||||||
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut);
|
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut);
|
||||||
@@ -329,7 +335,8 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
|
|||||||
if (isQuant_) {
|
if (isQuant_) {
|
||||||
QuantInit(scales);
|
QuantInit(scales);
|
||||||
}
|
}
|
||||||
uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
|
uint32_t expertIdsCnt = axisBS_ * axisK_;
|
||||||
|
uint32_t expertIdsSize = Ceil(expertIdsCnt * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
|
||||||
tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);
|
tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);
|
||||||
expertIdsTensor_ = expertIdsBuf_.Get<int32_t>();
|
expertIdsTensor_ = expertIdsBuf_.Get<int32_t>();
|
||||||
tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);
|
tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);
|
||||||
@@ -339,7 +346,22 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
|
|||||||
tpipe_->InitBuffer(getTotalBuf_,
|
tpipe_->InitBuffer(getTotalBuf_,
|
||||||
epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));
|
epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));
|
||||||
tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);
|
tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);
|
||||||
|
uint32_t hFp32Size = axisH_ * sizeof(float);
|
||||||
|
uint32_t xActivateMaskSize = axisBS_ * (Ceil(axisK_ * sizeof(bool), UB_ALIGN) * UB_ALIGN) * sizeof(half);
|
||||||
|
uint32_t bsAlign256 = Ceil(axisBS_ * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
|
||||||
|
uint32_t bsKAlign256 = Ceil(expertIdsCnt * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
|
||||||
|
expertIdsSize = Ceil(expertIdsSize, UB_ALIGN) * UB_ALIGN;
|
||||||
|
maxSize_ = hFp32Size > expertIdsSize ? hFp32Size : expertIdsSize;
|
||||||
|
maxSize_ = maxSize_ > xActivateMaskSize ? maxSize_ : xActivateMaskSize;
|
||||||
|
maxSize_ = maxSize_ > bsKAlign256 ? maxSize_ : bsKAlign256;
|
||||||
|
tpipe_->InitBuffer(gatherMaskTBuf_, maxSize_);
|
||||||
|
if constexpr (DynamicQuant || StaticQuant) {
|
||||||
|
dstExpBuf_ = receiveDataCastFloatBuf_; // 内存复用
|
||||||
|
subExpBuf_ = smoothScalesBuf_; // 内存复用
|
||||||
|
} else {
|
||||||
|
tpipe_->InitBuffer(dstExpBuf_, maxSize_); // BS * K * 4 = 32K
|
||||||
|
tpipe_->InitBuffer(subExpBuf_, maxSize_); // BS * K * 4 = 32K
|
||||||
|
}
|
||||||
moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
|
moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
|
||||||
if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
|
if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
|
||||||
axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
|
axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
|
||||||
@@ -370,11 +392,36 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Quant
|
|||||||
dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>();
|
dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <TemplateDispatchTypeClass>
|
||||||
|
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::TokenActiveMaskCal()
|
||||||
|
{
|
||||||
|
// 搬运x_active_mask, 当前仅用于计算有效token总数
|
||||||
|
LocalTensor<half> maskTmpTensor;
|
||||||
|
LocalTensor<half> sumOutTensor;
|
||||||
|
LocalTensor<bool> maskInputTensor;
|
||||||
|
axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
|
||||||
|
maskInputTensor = dstExpBuf_.Get<bool>();
|
||||||
|
maskTmpTensor = subExpBuf_.Get<half>();
|
||||||
|
sumOutTensor = gatherMaskTBuf_.Get<half>();
|
||||||
|
DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
|
||||||
|
DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
|
||||||
|
DataCopyPad(maskInputTensor, xActiveMaskGMTensor_, maskParams, maskCopyPadParams);
|
||||||
|
SyncFunc<AscendC::HardEvent::MTE2_V>();
|
||||||
|
LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.ReinterpretCast<int8_t>();
|
||||||
|
Cast(maskTmpTensor, maskInputInt8Tensor, RoundMode::CAST_NONE, axisBS_);
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
|
SumParams params{1, axisBsAlignSize_, axisBS_};
|
||||||
|
Sum(sumOutTensor, maskTmpTensor, params);
|
||||||
|
SyncFunc<AscendC::HardEvent::V_S>();
|
||||||
|
activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
|
||||||
|
sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
|
||||||
|
}
|
||||||
|
|
||||||
template <TemplateDispatchTypeClass>
|
template <TemplateDispatchTypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
|
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
|
||||||
{
|
{
|
||||||
uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_;
|
uint32_t sendTokenNum = activeMaskBsCnt_ / sharedUsedAivNum_;
|
||||||
uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_;
|
uint32_t remainderTokenNum = activeMaskBsCnt_ % sharedUsedAivNum_;
|
||||||
uint32_t newAivId = aivId_ - moeUsedAivNum_;
|
uint32_t newAivId = aivId_ - moeUsedAivNum_;
|
||||||
uint32_t startTokenId = sendTokenNum * newAivId;
|
uint32_t startTokenId = sendTokenNum * newAivId;
|
||||||
if (newAivId < remainderTokenNum) {
|
if (newAivId < remainderTokenNum) {
|
||||||
@@ -383,16 +430,16 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
|
|||||||
} else {
|
} else {
|
||||||
startTokenId += remainderTokenNum;
|
startTokenId += remainderTokenNum;
|
||||||
}
|
}
|
||||||
if (startTokenId >= axisBS_) {
|
if (startTokenId >= activeMaskBsCnt_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uint32_t endTokenId = startTokenId + sendTokenNum;
|
uint32_t endTokenId = startTokenId + sendTokenNum;
|
||||||
for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) {
|
for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) {
|
||||||
uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum);
|
uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum);
|
||||||
uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
|
uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
|
||||||
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
|
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
|
||||||
uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
|
uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
|
||||||
epRankId_ * axisBS_ / sharedExpertRankNum_;
|
epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
|
||||||
GlobalTensor<ExpandXOutType> dstWinGMTensor;
|
GlobalTensor<ExpandXOutType> dstWinGMTensor;
|
||||||
dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) +
|
dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) +
|
||||||
expertPerSizeOnWin_ * epRankId_));
|
expertPerSizeOnWin_ * epRankId_));
|
||||||
@@ -440,7 +487,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
|
|||||||
template <TemplateDispatchTypeClass>
|
template <TemplateDispatchTypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert()
|
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert()
|
||||||
{
|
{
|
||||||
uint32_t expertIdsCnt = axisBS_ * axisK_;
|
uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
|
||||||
uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_;
|
uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_;
|
||||||
uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_;
|
uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_;
|
||||||
uint32_t startTokenId = sendTokenNum * aivId_;
|
uint32_t startTokenId = sendTokenNum * aivId_;
|
||||||
@@ -496,7 +543,12 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
|
|||||||
template <TemplateDispatchTypeClass>
|
template <TemplateDispatchTypeClass>
|
||||||
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch()
|
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch()
|
||||||
{
|
{
|
||||||
uint32_t expertIdsCnt = axisBS_ * axisK_;
|
activeMaskBsCnt_ = axisBS_;
|
||||||
|
sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
|
||||||
|
if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
|
||||||
|
TokenActiveMaskCal();
|
||||||
|
}
|
||||||
|
uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
|
||||||
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U};
|
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U};
|
||||||
DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
|
DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
|
||||||
DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams);
|
DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams);
|
||||||
|
|||||||
@@ -14,5 +14,8 @@
|
|||||||
|
|
||||||
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
|
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
|
||||||
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
|
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
|
||||||
|
#define TemplateDispatchTypeClass \
|
||||||
|
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
|
||||||
|
bool IsNeedAllgater, uint32_t EXEC_FLAG
|
||||||
|
#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
|
||||||
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H
|
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H
|
||||||
|
|||||||
@@ -70,5 +70,6 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
|
|||||||
constexpr uint32_t WORKSPACE_STAGES = 4;
|
constexpr uint32_t WORKSPACE_STAGES = 4;
|
||||||
|
|
||||||
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
|
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
|
||||||
|
constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
|
||||||
|
|
||||||
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H
|
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H
|
||||||
|
|||||||
@@ -640,8 +640,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
|||||||
const at::Tensor &gmm1_permuted_weight_scale,
|
const at::Tensor &gmm1_permuted_weight_scale,
|
||||||
const at::Tensor &gmm2_weight,
|
const at::Tensor &gmm2_weight,
|
||||||
const at::Tensor &gmm2_weight_scale,
|
const at::Tensor &gmm2_weight_scale,
|
||||||
|
const at::Tensor &expert_scales,
|
||||||
const c10::optional<at::Tensor> &expert_smooth_scales,
|
const c10::optional<at::Tensor> &expert_smooth_scales,
|
||||||
const c10::optional<at::Tensor> &expert_scales,
|
const c10::optional<at::Tensor> &x_active_mask,
|
||||||
c10::string_view group_ep,
|
c10::string_view group_ep,
|
||||||
int64_t ep_rank_size,
|
int64_t ep_rank_size,
|
||||||
int64_t ep_rank_id,
|
int64_t ep_rank_id,
|
||||||
@@ -674,8 +675,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
|
|||||||
gmm1_permuted_weight_scale,
|
gmm1_permuted_weight_scale,
|
||||||
gmm2_weight,
|
gmm2_weight,
|
||||||
gmm2_weight_scale,
|
gmm2_weight_scale,
|
||||||
expert_smooth_scales,
|
|
||||||
expert_scales,
|
expert_scales,
|
||||||
|
expert_smooth_scales,
|
||||||
|
x_active_mask,
|
||||||
//input attrs
|
//input attrs
|
||||||
group_ep_ptr,
|
group_ep_ptr,
|
||||||
ep_rank_size,
|
ep_rank_size,
|
||||||
@@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
|
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
|
||||||
" Tensor gmm1_permuted_weight_scale,"
|
" Tensor gmm1_permuted_weight_scale,"
|
||||||
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
|
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
|
||||||
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
|
" Tensor expert_scales, Tensor? expert_smooth_scales=None,"
|
||||||
|
" Tensor? x_active_mask=None,"
|
||||||
" str group_ep='',"
|
" str group_ep='',"
|
||||||
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
|
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
|
||||||
" int shared_expert_num=1, int shared_expert_rank_num=0,"
|
" int shared_expert_num=1, int shared_expert_rank_num=0,"
|
||||||
|
|||||||
@@ -161,8 +161,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
|
|||||||
const at::Tensor &gmm1_permuted_weight_scale,
|
const at::Tensor &gmm1_permuted_weight_scale,
|
||||||
const at::Tensor &gmm2_weight,
|
const at::Tensor &gmm2_weight,
|
||||||
const at::Tensor &gmm2_weight_scale,
|
const at::Tensor &gmm2_weight_scale,
|
||||||
|
const at::Tensor &expert_scales,
|
||||||
const c10::optional<at::Tensor> &expert_smooth_scales,
|
const c10::optional<at::Tensor> &expert_smooth_scales,
|
||||||
const c10::optional<at::Tensor> &expert_scales,
|
const c10::optional<at::Tensor> &x_active_mask,
|
||||||
c10::string_view group_ep,
|
c10::string_view group_ep,
|
||||||
int64_t ep_rank_size,
|
int64_t ep_rank_size,
|
||||||
int64_t ep_rank_id,
|
int64_t ep_rank_id,
|
||||||
|
|||||||
@@ -16,6 +16,19 @@ torch.manual_seed(42)
|
|||||||
torch_npu.npu.config.allow_internal_format = True
|
torch_npu.npu.config.allow_internal_format = True
|
||||||
enable_custom_op()
|
enable_custom_op()
|
||||||
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
|
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
|
||||||
|
BASE_KWARGS = {
|
||||||
|
"batch_size": 64,
|
||||||
|
"token_hidden_size": 7168,
|
||||||
|
"moe_intermediate_size": 2048,
|
||||||
|
"ep_world_size": 16,
|
||||||
|
"moe_expert_num": 64,
|
||||||
|
"shared_expert_rank_num": 0,
|
||||||
|
"top_k": 8,
|
||||||
|
"test_bfloat16": True,
|
||||||
|
"enable_dynamic_bs": False,
|
||||||
|
"test_graph": False,
|
||||||
|
"with_mc2_mask": False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def redirect_output(log_file_path):
|
def redirect_output(log_file_path):
|
||||||
@@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module):
|
|||||||
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
|
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
|
||||||
gmm2_weight_scale.float(), requires_grad=False)
|
gmm2_weight_scale.float(), requires_grad=False)
|
||||||
|
|
||||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
|
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||||
|
x_active_mask):
|
||||||
raise NotImplementedError("To be implemented in subclass")
|
raise NotImplementedError("To be implemented in subclass")
|
||||||
|
|
||||||
def forward(self, x, expert_ids, smooth_scales, expert_scales):
|
def forward(self, x, expert_ids, smooth_scales, expert_scales,
|
||||||
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
|
x_active_mask):
|
||||||
|
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
|
||||||
|
x_active_mask)
|
||||||
|
|
||||||
|
|
||||||
class SmallOps(DecodeMoeOps):
|
class SmallOps(DecodeMoeOps):
|
||||||
@@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps):
|
|||||||
shared_expert_rank_num)
|
shared_expert_rank_num)
|
||||||
self.tp_hcomm_info = ""
|
self.tp_hcomm_info = ""
|
||||||
|
|
||||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
|
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||||
|
x_active_mask):
|
||||||
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
|
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||||
x=x,
|
x=x,
|
||||||
expert_ids=expert_ids,
|
expert_ids=expert_ids,
|
||||||
expert_scales=expert_scales,
|
expert_scales=expert_scales,
|
||||||
|
x_active_mask=x_active_mask,
|
||||||
group_ep=self.ep_hcomm_info,
|
group_ep=self.ep_hcomm_info,
|
||||||
ep_world_size=self.ep_world_size,
|
ep_world_size=self.ep_world_size,
|
||||||
ep_rank_id=self.global_rank_id,
|
ep_rank_id=self.global_rank_id,
|
||||||
@@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps):
|
|||||||
assist_info_for_combine=assist_info_for_combine,
|
assist_info_for_combine=assist_info_for_combine,
|
||||||
ep_send_counts=ep_send_counts,
|
ep_send_counts=ep_send_counts,
|
||||||
expert_scales=expert_scales,
|
expert_scales=expert_scales,
|
||||||
|
x_active_mask=x_active_mask,
|
||||||
group_ep=self.ep_hcomm_info,
|
group_ep=self.ep_hcomm_info,
|
||||||
ep_world_size=self.ep_world_size,
|
ep_world_size=self.ep_world_size,
|
||||||
ep_rank_id=self.global_rank_id,
|
ep_rank_id=self.global_rank_id,
|
||||||
@@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps):
|
|||||||
ep_world_size, moe_expert_num, global_rank_id,
|
ep_world_size, moe_expert_num, global_rank_id,
|
||||||
shared_expert_rank_num)
|
shared_expert_rank_num)
|
||||||
|
|
||||||
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
|
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
|
||||||
|
x_active_mask):
|
||||||
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
|
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
|
||||||
x=x,
|
x=x,
|
||||||
expert_ids=expert_ids,
|
expert_ids=expert_ids,
|
||||||
@@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps):
|
|||||||
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
|
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
|
||||||
gmm2_weight=self.gmm2_weight,
|
gmm2_weight=self.gmm2_weight,
|
||||||
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
|
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
|
||||||
expert_smooth_scales=smooth_scales,
|
|
||||||
expert_scales=expert_scales,
|
expert_scales=expert_scales,
|
||||||
|
expert_smooth_scales=smooth_scales,
|
||||||
|
x_active_mask=x_active_mask,
|
||||||
group_ep=self.ep_hcomm_info,
|
group_ep=self.ep_hcomm_info,
|
||||||
ep_rank_size=self.ep_world_size,
|
ep_rank_size=self.ep_world_size,
|
||||||
ep_rank_id=self.global_rank_id,
|
ep_rank_id=self.global_rank_id,
|
||||||
@@ -267,12 +288,13 @@ def generate_datas(batch_size,
|
|||||||
shared_expert_rank_num=0,
|
shared_expert_rank_num=0,
|
||||||
top_k=8,
|
top_k=8,
|
||||||
test_bfloat16=True,
|
test_bfloat16=True,
|
||||||
enable_dynamic_bs=False):
|
enable_dynamic_bs=False,
|
||||||
|
with_mc2_mask=False):
|
||||||
is_shared_expert = global_rank_id < shared_expert_rank_num
|
is_shared_expert = global_rank_id < shared_expert_rank_num
|
||||||
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
|
moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
|
||||||
shared_expert_rank_num)
|
shared_expert_rank_num)
|
||||||
actual_bs = int(
|
actual_bs = int(
|
||||||
torch.randint(1, batch_size, [1]).item(
|
torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
|
||||||
) if enable_dynamic_bs else batch_size)
|
) if enable_dynamic_bs else batch_size)
|
||||||
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
|
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
|
||||||
gmm1_input_dim = token_hidden_size
|
gmm1_input_dim = token_hidden_size
|
||||||
@@ -317,9 +339,16 @@ def generate_datas(batch_size,
|
|||||||
else:
|
else:
|
||||||
x = x.half()
|
x = x.half()
|
||||||
smooth_sales = None
|
smooth_sales = None
|
||||||
return (x, expert_ids, smooth_sales, expert_scales), \
|
x_active_mask = None
|
||||||
|
valid_token_num = actual_bs
|
||||||
|
if with_mc2_mask:
|
||||||
|
valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
|
||||||
|
x_active_mask = torch.cat(
|
||||||
|
(torch.ones(valid_token_num),
|
||||||
|
torch.zeros(actual_bs - valid_token_num))).bool()
|
||||||
|
return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \
|
||||||
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
|
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
|
||||||
actual_bs
|
actual_bs, valid_token_num
|
||||||
|
|
||||||
|
|
||||||
def run_once(local_rank_id,
|
def run_once(local_rank_id,
|
||||||
@@ -332,7 +361,8 @@ def run_once(local_rank_id,
|
|||||||
top_k=8,
|
top_k=8,
|
||||||
test_bfloat16=True,
|
test_bfloat16=True,
|
||||||
enable_dynamic_bs=False,
|
enable_dynamic_bs=False,
|
||||||
test_graph=False):
|
test_graph=False,
|
||||||
|
with_mc2_mask=False):
|
||||||
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
|
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
|
||||||
) if output_to_file(local_rank_id) else None
|
) if output_to_file(local_rank_id) else None
|
||||||
global_rank_id = local_rank_id # 单机
|
global_rank_id = local_rank_id # 单机
|
||||||
@@ -358,8 +388,8 @@ def run_once(local_rank_id,
|
|||||||
parameter = (batch_size, token_hidden_size, moe_intermediate_size,
|
parameter = (batch_size, token_hidden_size, moe_intermediate_size,
|
||||||
ep_world_size, moe_expert_num, global_rank_id,
|
ep_world_size, moe_expert_num, global_rank_id,
|
||||||
shared_expert_rank_num)
|
shared_expert_rank_num)
|
||||||
input_datas, weight_datas, actual_bs = generate_datas(
|
input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
|
||||||
*parameter, top_k, test_bfloat16, enable_dynamic_bs)
|
*parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
|
||||||
input_datas = [
|
input_datas = [
|
||||||
data.npu() if data is not None else None for data in input_datas
|
data.npu() if data is not None else None for data in input_datas
|
||||||
]
|
]
|
||||||
@@ -382,8 +412,8 @@ def run_once(local_rank_id,
|
|||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
log_file.close()
|
log_file.close()
|
||||||
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
|
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
|
||||||
torch.testing.assert_close(small_op_token_output.cpu(),
|
torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
|
||||||
fused_op_token_output.cpu(),
|
fused_op_token_output[0:valid_token_num].cpu(),
|
||||||
atol=2.0,
|
atol=2.0,
|
||||||
rtol=0.02)
|
rtol=0.02)
|
||||||
torch.testing.assert_close(small_op_count_output.cpu(),
|
torch.testing.assert_close(small_op_count_output.cpu(),
|
||||||
@@ -394,18 +424,16 @@ def run_once(local_rank_id,
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test():
|
def test_dispatch_gmm_combine_decode_base():
|
||||||
batch_size = 64
|
custom_kwargs = BASE_KWARGS
|
||||||
token_hidden_size = 7168
|
ep_world_size = custom_kwargs["ep_world_size"]
|
||||||
moe_intermediate_size = 2048
|
custom_args = tuple(custom_kwargs.values())
|
||||||
ep_world_size = 16
|
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
|
||||||
moe_expert_num = 64
|
|
||||||
shared_expert_rank_num = 0
|
|
||||||
top_k = 8
|
def test_dispatch_gmm_combine_decode_with_mc2_mask():
|
||||||
test_bfloat16 = True
|
custom_kwargs = BASE_KWARGS
|
||||||
enable_dynamic_bs = False
|
custom_kwargs["with_mc2_mask"] = True
|
||||||
test_graph = False
|
ep_world_size = custom_kwargs["ep_world_size"]
|
||||||
args = (batch_size, token_hidden_size, moe_intermediate_size,
|
custom_args = tuple(custom_kwargs.values())
|
||||||
ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
|
mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
|
||||||
test_bfloat16, enable_dynamic_bs, test_graph)
|
|
||||||
mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user