[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode adds an optional input,
x_active_mask. When it is provided,
only tokens whose mask value is True are dispatched and handled.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-19 16:31:48 +08:00
committed by GitHub
parent 636265be6d
commit 118b0ed346
14 changed files with 292 additions and 96 deletions

View File

@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight, const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale, const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional, const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional, const aclTensor *xActiveMaskOptional,
char *groupEp, char *groupEp,
int64_t epRankSize, int64_t epRankSize,
int64_t epRankId, int64_t epRankId,
@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight, const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale, const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional, const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional, const aclTensor *xActiveMaskOptional,
char *groupEp, char *groupEp,
int64_t epRankSize, int64_t epRankSize,
int64_t epRankId, int64_t epRankId,
@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
aclOpExecutor **executor) aclOpExecutor **executor)
{ {
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor); output, epRecvCount, workspaceSize, executor);
} }

View File

@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
const aclTensor *gmm1PermutedWeightScale, const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight, const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale, const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional, const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional, const aclTensor *xActiveMaskOptional,
char *groupEp, char *groupEp,
int64_t epRankSize, int64_t epRankSize,
int64_t epRankId, int64_t epRankId,

View File

@@ -45,14 +45,19 @@ public:
.DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND}) .Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales") this->Input("expert_smooth_scales")
.ParamType(OPTIONAL) .ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND}) .Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales") this->Input("x_active_mask")
.ParamType(OPTIONAL) .ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT}) .DataType({ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND}) .Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output") this->Output("output")

View File

@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3; constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4; constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5; constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6; constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7; constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;
constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256; constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXERT_NUM = 512; constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
constexpr uint32_t SUPPORT_TOP_K = 12; constexpr uint32_t SUPPORT_TOP_K = 12;
constexpr uint32_t ONE_DIMS = 1;
constexpr uint32_t TWO_DIMS = 2; constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512; constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168; constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h), OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED); return ge::GRAPH_FAILED);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
if (xActiveMaskStorageShape != nullptr) {
OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
OPS_LOG_E(nodeName, " xActiveMask scale shape dims must be 1, but current dim num is %lu.",
xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
return ge::GRAPH_FAILED);
}
return ge::GRAPH_SUCCESS; return ge::GRAPH_SUCCESS;
} }
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum; tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED); return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp); SetHcommCfg(context, tilingData, groupEp);
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) { const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
context->SetTilingKey(0); INPUT_SHARE_X_ACTIVE_MASK_INDEX);
} else { bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE); uint64_t tilingKey = 0;
if (xActiveMaskEnable) {
tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
} }
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
tilingKey |= EXEC_FLAG_DEEP_FUSE;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum); context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS; return ge::GRAPH_SUCCESS;
} }

View File

@@ -14,7 +14,8 @@
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
// input // input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output // output
GM_ADDR output, GM_ADDR outputRecvCount, GM_ADDR output, GM_ADDR outputRecvCount,
// system // system
@@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling); GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) { if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op; DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data); expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
op.Process(); op.Process();
} }
} }

View File

@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
@@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using GemmKernel = typename std::conditional< using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
@@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
gmexpertIds, gmexpertIds,
gmExpandIdx, gmExpandIdx,
gmEpSendCount, gmEpSendCount,
xActiveMask,
gmResvered, gmResvered,
gmOutputRecvCount, gmOutputRecvCount,
epRankSize, epRankSize,
@@ -241,7 +242,7 @@ public:
__aicore__ inline void Init( __aicore__ inline void Init(
// input // input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
// output // output
GM_ADDR output, GM_ADDR outputRecvCount, GM_ADDR output, GM_ADDR outputRecvCount,
// system // system
@@ -260,6 +261,7 @@ private:
GM_ADDR workspaceGM_; GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_; GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_; GM_ADDR gmexpertScales_;
GM_ADDR xActiveMask_;
uint32_t maxTokenNum_{0}; uint32_t maxTokenNum_{0};
uint32_t gmm1OutputDim_{0}; uint32_t gmm1OutputDim_{0};
@@ -291,7 +293,8 @@ template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init( __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
// input // input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output // output
GM_ADDR output, GM_ADDR outputRecvCount, GM_ADDR output, GM_ADDR outputRecvCount,
// system // system
@@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
gmOutputRecvCount_ = outputRecvCount; gmOutputRecvCount_ = outputRecvCount;
workspaceGM_ = workspaceGM; workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales; gmexpertScales_ = expert_scales;
xActiveMask_ = x_active_mask;
tilingData_ = tilingData; tilingData_ = tilingData;
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
@@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize); workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);
if constexpr (EXEC_FLAG == 0) { if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
if constexpr (g_coreType == AscendC::AIV) { if constexpr (g_coreType == AscendC::AIV) {
AscendC::TPipe tpipe; AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false> MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
dispatcher; dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process(); dispatcher.Process();
tpipe.Destroy(); tpipe.Destroy();
@@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
Gmm1BlockScheduler>( Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered, layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_); sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>(); AscendC::PipeBarrier<PIPE_ALL>();
@@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner; MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
if (g_coreType == AscendC::AIV) { if (g_coreType == AscendC::AIV) {
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_, combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
workspaceGM_, nullptr, tilingData_); workspaceGM_, nullptr, tilingData_);
} }
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler, GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,

View File

@@ -353,7 +353,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
} }
template <typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_, template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
class ElementGroupList_> class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
{ {
@@ -411,6 +411,7 @@ public:
GM_ADDR gmX; GM_ADDR gmX;
GM_ADDR debugGm; GM_ADDR debugGm;
GM_ADDR gmexpertIds; GM_ADDR gmexpertIds;
GM_ADDR gmXActiveMask;
GM_ADDR gmExpandIdx; GM_ADDR gmExpandIdx;
GM_ADDR gmEpSendCount; GM_ADDR gmEpSendCount;
@@ -438,7 +439,7 @@ public:
LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,
@@ -465,6 +466,7 @@ public:
gmExpandIdx(gmExpandIdx_), gmExpandIdx(gmExpandIdx_),
gmEpSendCount(gmEpSendCount_), gmEpSendCount(gmEpSendCount_),
gmOutputRecvCount(gmOutputRecvCount_), gmOutputRecvCount(gmOutputRecvCount_),
gmXActiveMask(gmXActiveMask_),
gmResvered(gmResvered_), gmResvered(gmResvered_),
epRankSize(epRankSize_), epRankSize(epRankSize_),
epRankId(epRankId_), epRankId(epRankId_),
@@ -635,6 +637,39 @@ public:
AscendC::SyncAll<false>(); AscendC::SyncAll<false>();
} }
// Counts the entries set in the x_active_mask input and stores the count in activeMaskBsCnt.
//
// gmXActiveMask: global-memory address of the mask; axisBS bool elements are read from it.
// ubOffset:      byte offset into resource.ubBuf at which this function places its scratch tensors.
//
// The mask is copied into local memory, converted int8 -> half, and reduced with Sum; element 0 of
// the Sum output is cast to an integer and written to the member activeMaskBsCnt.
// NOTE(review): the result is only a count, not a per-token selection; callers that use it as a
// token-range bound presumably expect the active tokens to form a leading prefix — confirm.
// NOTE(review): the count is accumulated in half precision, which is exact only for small integer
// values — confirm against the supported batch-size limit.
CATLASS_DEVICE
void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset)
{
// Running byte offset used to lay out three consecutive scratch tensors in resource.ubBuf.
int64_t subUbOffset = ubOffset;
// Scratch 1: raw mask bytes, also viewed as int8 so they can be fed to Cast.
AscendC::LocalTensor<bool> maskInputTensor = (resource.ubBuf.template
GetBufferByByte<bool>(subUbOffset));
AscendC::LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.template ReinterpretCast<int8_t>();
subUbOffset += CEIL_UP(axisBS * sizeof(bool));
// Scratch 2: the mask converted to half, used as the Sum input.
AscendC::LocalTensor<half> maskTmpTensor = (resource.ubBuf.template
GetBufferByByte<half>(subUbOffset));
subUbOffset += CEIL_UP(axisBS * sizeof(half));
// Scratch 3: Sum output; only element 0 is read back below.
AscendC::LocalTensor<half> sumOutTensor = (resource.ubBuf.template
GetBufferByByte<half>(subUbOffset));
// subUbOffset is not read again after this increment within this function.
subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
AscendC::GlobalTensor<bool> xActiveMaskGMTensor;
xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask);
// Rounded-up mask length, passed to Sum as the second SumParams field.
uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool));
// Single-block copy of exactly axisBS * sizeof(bool) bytes; remaining fields are zero and padding is disabled.
AscendC::DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS * sizeof(bool)), 0U, 0U, 0U};
AscendC::DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams);
// Order the copy above ahead of the Cast that reads its destination (MTE2_V, event id 0).
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
// Convert the first axisBS mask bytes to half so they can be summed.
AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS);
AscendC::PipeBarrier<PIPE_V>();
// Reduce one row: axisBsAlignSize is the (aligned) row length, axisBS the number of elements summed.
AscendC::SumParams params{1, axisBsAlignSize, axisBS};
AscendC::Sum(sumOutTensor, maskTmpTensor, params);
// Order the Sum ahead of the scalar GetValue read below (V_S, event id 0).
AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
// Publish the number of active tokens; the cast goes through int32_t before the assignment.
activeMaskBsCnt = static_cast<int32_t>(sumOutTensor.GetValue(0));
}
CATLASS_DEVICE CATLASS_DEVICE
void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset)
{ {
@@ -778,8 +813,8 @@ public:
void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale)
{ {
uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; uint32_t newAivId = sendCoreIdx - sendToMoeAivNum;
uint32_t sendTokenNum = axisBS / sendToShareAivNum; uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum;
uint32_t remainderTokenNum = axisBS % sendToShareAivNum; uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum;
uint32_t startTokenId = sendTokenNum * newAivId; uint32_t startTokenId = sendTokenNum * newAivId;
if (newAivId < remainderTokenNum) { if (newAivId < remainderTokenNum) {
sendTokenNum += 1; sendTokenNum += 1;
@@ -788,7 +823,7 @@ public:
startTokenId += remainderTokenNum; startTokenId += remainderTokenNum;
} }
uint32_t endTokenId = startTokenId + sendTokenNum; uint32_t endTokenId = startTokenId + sendTokenNum;
if (startTokenId >= axisBS) { if (startTokenId >= activeMaskBsCnt) {
return; return;
} }
@@ -962,10 +997,13 @@ public:
} }
CATLASS_DEVICE void CATLASS_DEVICE void
SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx) SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask)
{ {
ubOffset = 0; ubOffset = 0;
expertIdsCnt = axisBS * axisK; if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
TokenActiveMaskCal(gmXActiveMask, ubOffset);
}
expertIdsCnt = activeMaskBsCnt * axisK;
AscendC::GlobalTensor<int32_t> expertIdsGMTensor_; AscendC::GlobalTensor<int32_t> expertIdsGMTensor_;
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds);
@@ -1372,6 +1410,7 @@ public:
hCommuSize = hOutSize + scaleParamPad; hCommuSize = hOutSize + scaleParamPad;
axisHCommu = hCommuSize / sizeof(int8_t); axisHCommu = hCommuSize / sizeof(int8_t);
axisBS = params.bs; axisBS = params.bs;
activeMaskBsCnt = axisBS;
axisK = params.topK; axisK = params.topK;
uint32_t maxAxisBs = params.globalBs / epRankSize; uint32_t maxAxisBs = params.globalBs / epRankSize;
@@ -1489,7 +1528,7 @@ public:
AivInitState(); AivInitState();
if (isSendCore) { if (isSendCore) {
SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA, SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
(GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx); (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
} }
if (isRecvCore) { if (isRecvCore) {
RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount, RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
@@ -1596,6 +1635,7 @@ private:
uint32_t hCommuSize{0}; uint32_t hCommuSize{0};
uint32_t axisHCommu{0}; uint32_t axisHCommu{0};
uint32_t axisBS{0}; uint32_t axisBS{0};
uint32_t activeMaskBsCnt{0};
uint32_t axisK{0}; uint32_t axisK{0};
uint32_t totalTokenCount{0}; uint32_t totalTokenCount{0};
uint32_t expertIdsCnt{0}; uint32_t expertIdsCnt{0};

View File

@@ -62,7 +62,7 @@ class CamMoeDistributeCombine
public: public:
__aicore__ inline CamMoeDistributeCombine(){}; __aicore__ inline CamMoeDistributeCombine(){};
__aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
const DispatchGmmCombineDecodeTilingData *tilingData); const DispatchGmmCombineDecodeTilingData *tilingData);
__aicore__ inline void Process(); __aicore__ inline void Process();
__aicore__ inline void AllToAllSend(); __aicore__ inline void AllToAllSend();
@@ -145,6 +145,7 @@ private:
GlobalTensor<ExpandIdxType> epSendCountGM_; GlobalTensor<ExpandIdxType> epSendCountGM_;
GlobalTensor<ExpandIdxType> tpSendCountGM_; GlobalTensor<ExpandIdxType> tpSendCountGM_;
GlobalTensor<float> expandScalesGM_; GlobalTensor<float> expandScalesGM_;
GlobalTensor<bool> xActiveMaskGM_;
GlobalTensor<ExpandXType> expandOutGlobal_; GlobalTensor<ExpandXType> expandOutGlobal_;
GlobalTensor<ExpandXType> rankWindow_; GlobalTensor<ExpandXType> rankWindow_;
GlobalTensor<int32_t> rankStates_; GlobalTensor<int32_t> rankStates_;
@@ -169,6 +170,8 @@ private:
CombineCalcInfo calcInfo_; CombineCalcInfo calcInfo_;
uint32_t axisBS_{0}; uint32_t axisBS_{0};
uint32_t axisMaxBs_{0}; uint32_t axisMaxBs_{0};
uint32_t axisBsAlignSize_{0};
uint64_t activeMaskBsCnt_{0};
uint32_t axisH_{0}; uint32_t axisH_{0};
uint32_t axisK_{0}; uint32_t axisK_{0};
uint32_t aivNum_{0}; uint32_t aivNum_{0};
@@ -224,6 +227,9 @@ private:
TBuf<> gatherMaskOutBuf_; // gather mask output buf TBuf<> gatherMaskOutBuf_; // gather mask output buf
TBuf<> gatherTmpBuf_; TBuf<> gatherTmpBuf_;
TBuf<> statusSumOutBuf_; TBuf<> statusSumOutBuf_;
TBuf<> xActMaskTBuf_;
TBuf<> xActMaskCastTBuf_;
TBuf<> xActMaskSumTBuf_;
float sumTarget_{0.0}; float sumTarget_{0.0};
int32_t epStateValue_; int32_t epStateValue_;
bool isShardExpert_{false}; bool isShardExpert_{false};
@@ -231,7 +237,7 @@ private:
template <TemplateMC2TypeClass> template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init( __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask,
GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{ {
tpipe_ = pipe; tpipe_ = pipe;
@@ -264,8 +270,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount); epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales); expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
xActiveMaskGM_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
activeMaskBsCnt_ = axisBS_;
axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
@@ -395,6 +403,27 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuf
tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float)); tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t)); tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
tpipe_->InitBuffer(xActMaskTBuf_, axisBsAlignSize_);
tpipe_->InitBuffer(xActMaskCastTBuf_, axisBsAlignSize_ * sizeof(half));
tpipe_->InitBuffer(xActMaskSumTBuf_, axisBsAlignSize_ * sizeof(half));
LocalTensor<bool> xActiveMaskTensor = xActMaskTBuf_.Get<bool>();
LocalTensor<half> tempTensor = xActMaskCastTBuf_.Get<half>();
LocalTensor<half> sumOutTensor = xActMaskSumTBuf_.Get<half>();
DataCopyExtParams xActiveMaskParams{1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
DataCopyPadExtParams<bool> xActiveMaskCopyPadParams{false, 0U, 0U, 0U};
DataCopyPad(xActiveMaskTensor, xActiveMaskGM_, xActiveMaskParams, xActiveMaskCopyPadParams);
SyncFunc<AscendC::HardEvent::MTE2_V>();
LocalTensor<int8_t> xActiveMaskInt8Tensor = xActiveMaskTensor.ReinterpretCast<int8_t>();
Cast(tempTensor, xActiveMaskInt8Tensor, RoundMode::CAST_NONE, axisBS_);
PipeBarrier<PIPE_V>();
SumParams params{1, axisBsAlignSize_, axisBS_};
Sum(sumOutTensor, tempTensor, params);
SyncFunc<AscendC::HardEvent::V_S>();
activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
}
} }
template <TemplateMC2TypeClass> template <TemplateMC2TypeClass>
@@ -645,13 +674,16 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatc
template <TemplateMC2TypeClass> template <TemplateMC2TypeClass>
__aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy() __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
{ {
if (activeMaskBsCnt_ == 0U) {
return;
}
uint32_t beginIndex = 0; uint32_t beginIndex = 0;
uint32_t endIndex = 0; uint32_t endIndex = 0;
uint32_t processLen = 0; uint32_t processLen = 0;
uint32_t tokenOffset = 0; uint32_t tokenOffset = 0;
if (axisBS_ < aivNum_) { if (activeMaskBsCnt_ < aivNum_) {
uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_ uint32_t aivNumPerToken = aivNum_ / activeMaskBsCnt_; // activeMaskBsCnt_ < aivNum_
if (coreIdx_ >= (axisBS_ * aivNumPerToken)) { if (coreIdx_ >= (activeMaskBsCnt_ * aivNumPerToken)) {
return; return;
} }
uint32_t tokenIndex = coreIdx_ / aivNumPerToken; uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
@@ -663,8 +695,8 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
beginIndex = tokenIndex; beginIndex = tokenIndex;
endIndex = beginIndex + 1U; endIndex = beginIndex + 1U;
} else { } else {
uint32_t tokenPerAivNum = axisBS_ / aivNum_; uint32_t tokenPerAivNum = activeMaskBsCnt_ / aivNum_;
uint32_t remainderToken = axisBS_ % aivNum_; uint32_t remainderToken = activeMaskBsCnt_ % aivNum_;
beginIndex = tokenPerAivNum * coreIdx_; beginIndex = tokenPerAivNum * coreIdx_;
if (coreIdx_ < remainderToken) { if (coreIdx_ < remainderToken) {
tokenPerAivNum++; tokenPerAivNum++;
@@ -723,10 +755,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
} }
LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>(); LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
if (sharedExpertRankNum_ > 0U) { if (sharedExpertRankNum_ > 0U) {
uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
epRankId_ * axisBS_ / sharedExpertRankNum_; epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
__gm__ ExpandXType *shareAddr = __gm__ ExpandXType *shareAddr =
(__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) + (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
(tokenIndex - preCnt) * axisH_ + tokenOffset; (tokenIndex - preCnt) * axisH_ + tokenOffset;

View File

@@ -22,6 +22,7 @@ constexpr uint8_t BUFFER_NUM = 2;
constexpr uint32_t STATE_OFFSET = 512; // state space offset constexpr uint32_t STATE_OFFSET = 512; // state space offset
constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
constexpr uint32_t UB_ALIGN = 32; constexpr uint32_t UB_ALIGN = 32;
constexpr uint64_t ALIGNED_LEN_256 = 256UL;
constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
constexpr uint8_t COMM_NUM = 2; constexpr uint8_t COMM_NUM = 2;
constexpr uint8_t COMM_EP_IDX = 0; constexpr uint8_t COMM_EP_IDX = 0;
@@ -47,18 +48,13 @@ __aicore__ inline void SyncFunc()
AscendC::WaitFlag<event>(eventID); AscendC::WaitFlag<event>(eventID);
} }
#define TemplateDispatchTypeClass \
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
bool IsNeedAllgater
#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater
using namespace AscendC; using namespace AscendC;
template <TemplateDispatchTypeClass> template <TemplateDispatchTypeClass>
class CamMoeDistributeDispatch class CamMoeDistributeDispatch
{ {
public: public:
__aicore__ inline CamMoeDistributeDispatch(){}; __aicore__ inline CamMoeDistributeDispatch(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut,
GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
@@ -80,6 +76,7 @@ private:
__aicore__ inline void AllGatherSetStatusAndWait(); __aicore__ inline void AllGatherSetStatusAndWait();
__aicore__ inline void ResetStatus(); __aicore__ inline void ResetStatus();
__aicore__ inline void QuantInit(GM_ADDR scales); __aicore__ inline void QuantInit(GM_ADDR scales);
__aicore__ inline void TokenActiveMaskCal();
__aicore__ inline void AllgatherProcessOut(); __aicore__ inline void AllgatherProcessOut();
__aicore__ inline void UpdataMultiMoeTokenNumsOut(); __aicore__ inline void UpdataMultiMoeTokenNumsOut();
__aicore__ inline void UpdataTokenNumsOut(); __aicore__ inline void UpdataTokenNumsOut();
@@ -112,6 +109,7 @@ private:
GlobalTensor<XType> xGMTensor_; GlobalTensor<XType> xGMTensor_;
GlobalTensor<int32_t> expertIdsGMTensor_; GlobalTensor<int32_t> expertIdsGMTensor_;
GlobalTensor<float> scalesGMTensor_; GlobalTensor<float> scalesGMTensor_;
GlobalTensor<bool> xActiveMaskGMTensor_;
GlobalTensor<ExpandXOutType> expandXOutGMTensor_; GlobalTensor<ExpandXOutType> expandXOutGMTensor_;
GlobalTensor<float> dynamicScalesOutGMTensor_; GlobalTensor<float> dynamicScalesOutGMTensor_;
GlobalTensor<int64_t> expertTokenNumsOutGMTensor_; GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;
@@ -144,6 +142,9 @@ private:
TBuf<> rowMaxBuf_; TBuf<> rowMaxBuf_;
TBuf<> receiveDataCastFloatBuf_; TBuf<> receiveDataCastFloatBuf_;
TBuf<> smoothScalesBuf_; TBuf<> smoothScalesBuf_;
TBuf<> dstExpBuf_;
TBuf<> subExpBuf_;
TBuf<> gatherMaskTBuf_;
TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_; TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_;
TQue<QuePosition::VECIN, 1> xInQueue_; TQue<QuePosition::VECIN, 1> xInQueue_;
TQue<QuePosition::VECOUT, 1> xOutQueue_; TQue<QuePosition::VECOUT, 1> xOutQueue_;
@@ -161,7 +162,10 @@ private:
GM_ADDR tpLocalStatusWindowGM_; GM_ADDR tpLocalStatusWindowGM_;
GlobalTensor<GM_ADDR> peerMemsAddrGm_; GlobalTensor<GM_ADDR> peerMemsAddrGm_;
uint32_t axisBS_{0}; uint32_t axisBS_{0};
uint32_t axisBsAlignSize_{0};
uint64_t activeMaskBsCnt_{0};
uint32_t axisMaxBS_{0}; uint32_t axisMaxBS_{0};
uint64_t sendToMoeExpTokenCnt_{0};
uint32_t axisH_{0}; uint32_t axisH_{0};
uint32_t axisK_{0}; uint32_t axisK_{0};
uint32_t aivNum_{0}; uint32_t aivNum_{0};
@@ -183,6 +187,7 @@ private:
uint32_t hSize_{0}; uint32_t hSize_{0};
uint32_t hOutSize_{0}; uint32_t hOutSize_{0};
uint32_t hCommuSize_{0}; uint32_t hCommuSize_{0};
uint32_t maxSize_{0};
uint32_t scaleParamPad_{0}; uint32_t scaleParamPad_{0};
uint32_t axisHCommu_{0}; uint32_t axisHCommu_{0};
uint32_t startExpertId_; uint32_t startExpertId_;
@@ -216,7 +221,7 @@ private:
template <TemplateDispatchTypeClass> template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init( __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
{ {
@@ -268,6 +273,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
tpWorldSize_ = 1; tpWorldSize_ = 1;
xGMTensor_.SetGlobalBuffer((__gm__ XType *)x); xGMTensor_.SetGlobalBuffer((__gm__ XType *)x);
expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds);
xActiveMaskGMTensor_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut); expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut);
dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut); expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut);
@@ -329,7 +335,8 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
if (isQuant_) { if (isQuant_) {
QuantInit(scales); QuantInit(scales);
} }
uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; uint32_t expertIdsCnt = axisBS_ * axisK_;
uint32_t expertIdsSize = Ceil(expertIdsCnt * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize); tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);
expertIdsTensor_ = expertIdsBuf_.Get<int32_t>(); expertIdsTensor_ = expertIdsBuf_.Get<int32_t>();
tpipe_->InitBuffer(expertCountBuf_, expertIdsSize); tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);
@@ -339,7 +346,22 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
tpipe_->InitBuffer(getTotalBuf_, tpipe_->InitBuffer(getTotalBuf_,
epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t)); epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));
tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2); tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);
uint32_t hFp32Size = axisH_ * sizeof(float);
uint32_t xActivateMaskSize = axisBS_ * (Ceil(axisK_ * sizeof(bool), UB_ALIGN) * UB_ALIGN) * sizeof(half);
uint32_t bsAlign256 = Ceil(axisBS_ * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
uint32_t bsKAlign256 = Ceil(expertIdsCnt * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
expertIdsSize = Ceil(expertIdsSize, UB_ALIGN) * UB_ALIGN;
maxSize_ = hFp32Size > expertIdsSize ? hFp32Size : expertIdsSize;
maxSize_ = maxSize_ > xActivateMaskSize ? maxSize_ : xActivateMaskSize;
maxSize_ = maxSize_ > bsKAlign256 ? maxSize_ : bsKAlign256;
tpipe_->InitBuffer(gatherMaskTBuf_, maxSize_);
if constexpr (DynamicQuant || StaticQuant) {
dstExpBuf_ = receiveDataCastFloatBuf_; // 内存复用
subExpBuf_ = smoothScalesBuf_; // 内存复用
} else {
tpipe_->InitBuffer(dstExpBuf_, maxSize_); // BS * K * 4 = 32K
tpipe_->InitBuffer(subExpBuf_, maxSize_); // BS * K * 4 = 32K
}
moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK; moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK && if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) { axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
@@ -370,11 +392,36 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Quant
dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>(); dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>();
} }
// Counts the tokens flagged in x_active_mask and derives how many
// (token, expert) pairs will be dispatched.
//
// Steps: copy axisBS_ mask bytes from xActiveMaskGMTensor_ into a local buffer,
// widen them from int8 to half, sum them, then read the scalar result.
// Writes axisBsAlignSize_, activeMaskBsCnt_ and sendToMoeExpTokenCnt_.
// AlltoAllDispatch() calls this only when EXEC_FLAG has EXEC_FLAG_X_ACTIVE_MASK set;
// otherwise activeMaskBsCnt_ stays equal to axisBS_.
//
// NOTE(review): only the count is kept, not the per-token flags, so the callers
// that iterate over activeMaskBsCnt_ presumably treat the first activeMaskBsCnt_
// tokens as the active ones - confirm that the mask is always a leading run of
// true values.
// NOTE(review): the count is produced as a half value; half represents consecutive
// integers exactly only up to 2048 - confirm axisBS_ can never exceed that.
// NOTE(review): dstExpBuf_ / subExpBuf_ are used as scratch here, and Init() aliases
// them to receiveDataCastFloatBuf_ / smoothScalesBuf_ when DynamicQuant || StaticQuant,
// so this presumably has to run before those buffers hold live data - confirm.
template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::TokenActiveMaskCal()
{
// Copy in x_active_mask; for now it is only used to compute the total number of valid tokens.
LocalTensor<half> maskTmpTensor;
LocalTensor<half> sumOutTensor;
LocalTensor<bool> maskInputTensor;
// Byte length of the mask rounded up to a multiple of UB_ALIGN.
axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
// Scratch views: raw mask bytes, the half-typed copy, and the reduction output.
maskInputTensor = dstExpBuf_.Get<bool>();
maskTmpTensor = subExpBuf_.Get<half>();
sumOutTensor = gatherMaskTBuf_.Get<half>();
// Single block of axisBS_ bytes, no stride and no padding fill.
DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
DataCopyPad(maskInputTensor, xActiveMaskGMTensor_, maskParams, maskCopyPadParams);
// Order the copy-in (MTE2) before the vector Cast (V) that reads it.
SyncFunc<AscendC::HardEvent::MTE2_V>();
// Reinterpret the bool bytes as int8 so they can be cast to half.
LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.ReinterpretCast<int8_t>();
Cast(maskTmpTensor, maskInputInt8Tensor, RoundMode::CAST_NONE, axisBS_);
// The Cast must finish before Sum reads maskTmpTensor.
PipeBarrier<PIPE_V>();
// presumably {outer rows, padded inner length, valid element count} - confirm against the SumParams docs
SumParams params{1, axisBsAlignSize_, axisBS_};
Sum(sumOutTensor, maskTmpTensor, params);
// Order the vector result (V) before the scalar read (S) below.
SyncFunc<AscendC::HardEvent::V_S>();
// Element 0 holds the reduced value; it is converted to int32_t and stored in the uint64_t member.
activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
// Each active token is routed to axisK_ experts.
sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
}
template <TemplateDispatchTypeClass> template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert() __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
{ {
uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_; uint32_t sendTokenNum = activeMaskBsCnt_ / sharedUsedAivNum_;
uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_; uint32_t remainderTokenNum = activeMaskBsCnt_ % sharedUsedAivNum_;
uint32_t newAivId = aivId_ - moeUsedAivNum_; uint32_t newAivId = aivId_ - moeUsedAivNum_;
uint32_t startTokenId = sendTokenNum * newAivId; uint32_t startTokenId = sendTokenNum * newAivId;
if (newAivId < remainderTokenNum) { if (newAivId < remainderTokenNum) {
@@ -383,16 +430,16 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
} else { } else {
startTokenId += remainderTokenNum; startTokenId += remainderTokenNum;
} }
if (startTokenId >= axisBS_) { if (startTokenId >= activeMaskBsCnt_) {
return; return;
} }
uint32_t endTokenId = startTokenId + sendTokenNum; uint32_t endTokenId = startTokenId + sendTokenNum;
for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) { for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) {
uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum); uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum);
uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
epRankId_ * axisBS_ / sharedExpertRankNum_; epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
GlobalTensor<ExpandXOutType> dstWinGMTensor; GlobalTensor<ExpandXOutType> dstWinGMTensor;
dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) +
expertPerSizeOnWin_ * epRankId_)); expertPerSizeOnWin_ * epRankId_));
@@ -440,7 +487,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
template <TemplateDispatchTypeClass> template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert() __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert()
{ {
uint32_t expertIdsCnt = axisBS_ * axisK_; uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_;
uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_;
uint32_t startTokenId = sendTokenNum * aivId_; uint32_t startTokenId = sendTokenNum * aivId_;
@@ -496,7 +543,12 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
template <TemplateDispatchTypeClass> template <TemplateDispatchTypeClass>
__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch() __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch()
{ {
uint32_t expertIdsCnt = axisBS_ * axisK_; activeMaskBsCnt_ = axisBS_;
sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
TokenActiveMaskCal();
}
uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U};
DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U}; DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams);

View File

@@ -14,5 +14,8 @@
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG #define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG #define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
// Template parameter list (TemplateDispatchTypeClass) and the matching argument
// list (TemplateDispatchTypeFunc) used for CamMoeDistributeDispatch and its
// out-of-class member definitions. The two lists must stay in the same order.
// EXEC_FLAG is a compile-time bitmask; the dispatch code tests it against
// EXEC_FLAG_X_ACTIVE_MASK with `if constexpr`.
#define TemplateDispatchTypeClass \
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
bool IsNeedAllgater, uint32_t EXEC_FLAG
#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H #endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H

View File

@@ -70,5 +70,6 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t WORKSPACE_STAGES = 4;
constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);
#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H

View File

@@ -640,8 +640,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale, const at::Tensor &gmm2_weight_scale,
const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales, const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &expert_scales, const c10::optional<at::Tensor> &x_active_mask,
c10::string_view group_ep, c10::string_view group_ep,
int64_t ep_rank_size, int64_t ep_rank_size,
int64_t ep_rank_id, int64_t ep_rank_id,
@@ -674,8 +675,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
gmm1_permuted_weight_scale, gmm1_permuted_weight_scale,
gmm2_weight, gmm2_weight,
gmm2_weight_scale, gmm2_weight_scale,
expert_smooth_scales,
expert_scales, expert_scales,
expert_smooth_scales,
x_active_mask,
//input attrs //input attrs
group_ep_ptr, group_ep_ptr,
ep_rank_size, ep_rank_size,
@@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
" Tensor gmm1_permuted_weight_scale," " Tensor gmm1_permuted_weight_scale,"
" Tensor gmm2_weight, Tensor gmm2_weight_scale," " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None," " Tensor expert_scales, Tensor? expert_smooth_scales=None,"
" Tensor? x_active_mask=None,"
" str group_ep=''," " str group_ep='',"
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0," " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
" int shared_expert_num=1, int shared_expert_rank_num=0," " int shared_expert_num=1, int shared_expert_rank_num=0,"

View File

@@ -161,8 +161,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm1_permuted_weight_scale,
const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight,
const at::Tensor &gmm2_weight_scale, const at::Tensor &gmm2_weight_scale,
const at::Tensor &expert_scales,
const c10::optional<at::Tensor> &expert_smooth_scales, const c10::optional<at::Tensor> &expert_smooth_scales,
const c10::optional<at::Tensor> &expert_scales, const c10::optional<at::Tensor> &x_active_mask,
c10::string_view group_ep, c10::string_view group_ep,
int64_t ep_rank_size, int64_t ep_rank_size,
int64_t ep_rank_id, int64_t ep_rank_id,

View File

@@ -16,6 +16,19 @@ torch.manual_seed(42)
torch_npu.npu.config.allow_internal_format = True torch_npu.npu.config.allow_internal_format = True
enable_custom_op() enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs" LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
# Default arguments shared by the test entry points in this file.
# NOTE: the tests pass tuple(BASE_KWARGS.values()) positionally to run_once, so
# the insertion order of the keys is load-bearing: it must match run_once's
# parameter order (presumably after the leading local_rank_id supplied by the
# process spawner - confirm). Do not reorder or insert keys in the middle.
BASE_KWARGS = {
"batch_size": 64,
"token_hidden_size": 7168,
"moe_intermediate_size": 2048,
"ep_world_size": 16,  # also used as the number of spawned processes
"moe_expert_num": 64,
"shared_expert_rank_num": 0,
"top_k": 8,
"test_bfloat16": True,
"enable_dynamic_bs": False,  # when True, generate_datas draws a random actual batch size
"test_graph": False,
"with_mc2_mask": False  # when True, generate_datas also builds an x_active_mask
}
def redirect_output(log_file_path): def redirect_output(log_file_path):
@@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module):
self.gmm2_weight_scale_fp32 = torch.nn.Parameter( self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False) gmm2_weight_scale.float(), requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError("To be implemented in subclass")
def forward(self, x, expert_ids, smooth_scales, expert_scales): def forward(self, x, expert_ids, smooth_scales, expert_scales,
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales) x_active_mask):
return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
x_active_mask)
class SmallOps(DecodeMoeOps): class SmallOps(DecodeMoeOps):
@@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num) shared_expert_rank_num)
self.tp_hcomm_info = "" self.tp_hcomm_info = ""
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
outputs = torch_npu.npu_moe_distribute_dispatch_v2( outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x, x=x,
expert_ids=expert_ids, expert_ids=expert_ids,
expert_scales=expert_scales, expert_scales=expert_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info, group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size, ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id, ep_rank_id=self.global_rank_id,
@@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps):
assist_info_for_combine=assist_info_for_combine, assist_info_for_combine=assist_info_for_combine,
ep_send_counts=ep_send_counts, ep_send_counts=ep_send_counts,
expert_scales=expert_scales, expert_scales=expert_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info, group_ep=self.ep_hcomm_info,
ep_world_size=self.ep_world_size, ep_world_size=self.ep_world_size,
ep_rank_id=self.global_rank_id, ep_rank_id=self.global_rank_id,
@@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps):
ep_world_size, moe_expert_num, global_rank_id, ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num) shared_expert_rank_num)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
x_active_mask):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode( output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x, x=x,
expert_ids=expert_ids, expert_ids=expert_ids,
@@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps):
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32, gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight, gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale_fp32, gmm2_weight_scale=self.gmm2_weight_scale_fp32,
expert_smooth_scales=smooth_scales,
expert_scales=expert_scales, expert_scales=expert_scales,
expert_smooth_scales=smooth_scales,
x_active_mask=x_active_mask,
group_ep=self.ep_hcomm_info, group_ep=self.ep_hcomm_info,
ep_rank_size=self.ep_world_size, ep_rank_size=self.ep_world_size,
ep_rank_id=self.global_rank_id, ep_rank_id=self.global_rank_id,
@@ -267,12 +288,13 @@ def generate_datas(batch_size,
shared_expert_rank_num=0, shared_expert_rank_num=0,
top_k=8, top_k=8,
test_bfloat16=True, test_bfloat16=True,
enable_dynamic_bs=False): enable_dynamic_bs=False,
with_mc2_mask=False):
is_shared_expert = global_rank_id < shared_expert_rank_num is_shared_expert = global_rank_id < shared_expert_rank_num
moe_expert_num_per_rank = moe_expert_num // (ep_world_size - moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
shared_expert_rank_num) shared_expert_rank_num)
actual_bs = int( actual_bs = int(
torch.randint(1, batch_size, [1]).item( torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
) if enable_dynamic_bs else batch_size) ) if enable_dynamic_bs else batch_size)
local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
gmm1_input_dim = token_hidden_size gmm1_input_dim = token_hidden_size
@@ -317,9 +339,16 @@ def generate_datas(batch_size,
else: else:
x = x.half() x = x.half()
smooth_sales = None smooth_sales = None
return (x, expert_ids, smooth_sales, expert_scales), \ x_active_mask = None
valid_token_num = actual_bs
if with_mc2_mask:
valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
x_active_mask = torch.cat(
(torch.ones(valid_token_num),
torch.zeros(actual_bs - valid_token_num))).bool()
return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \
(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \ (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
actual_bs actual_bs, valid_token_num
def run_once(local_rank_id, def run_once(local_rank_id,
@@ -332,7 +361,8 @@ def run_once(local_rank_id,
top_k=8, top_k=8,
test_bfloat16=True, test_bfloat16=True,
enable_dynamic_bs=False, enable_dynamic_bs=False,
test_graph=False): test_graph=False,
with_mc2_mask=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log" log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None ) if output_to_file(local_rank_id) else None
global_rank_id = local_rank_id # 单机 global_rank_id = local_rank_id # 单机
@@ -358,8 +388,8 @@ def run_once(local_rank_id,
parameter = (batch_size, token_hidden_size, moe_intermediate_size, parameter = (batch_size, token_hidden_size, moe_intermediate_size,
ep_world_size, moe_expert_num, global_rank_id, ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num) shared_expert_rank_num)
input_datas, weight_datas, actual_bs = generate_datas( input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
*parameter, top_k, test_bfloat16, enable_dynamic_bs) *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
input_datas = [ input_datas = [
data.npu() if data is not None else None for data in input_datas data.npu() if data is not None else None for data in input_datas
] ]
@@ -382,8 +412,8 @@ def run_once(local_rank_id,
if log_file is not None: if log_file is not None:
log_file.close() log_file.close()
small_op_count_output = from_inclusive_prefix_sum(small_op_count_output) small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
torch.testing.assert_close(small_op_token_output.cpu(), torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
fused_op_token_output.cpu(), fused_op_token_output[0:valid_token_num].cpu(),
atol=2.0, atol=2.0,
rtol=0.02) rtol=0.02)
torch.testing.assert_close(small_op_count_output.cpu(), torch.testing.assert_close(small_op_count_output.cpu(),
@@ -394,18 +424,16 @@ def run_once(local_rank_id,
@torch.inference_mode() @torch.inference_mode()
def test_dispatch_gmm_combine_decode_base():
    """Run the fused-vs-small-op comparison with the default configuration.

    Spawns one ``run_once`` process per EP rank, passing the values of
    ``BASE_KWARGS`` positionally (so the dict's insertion order matters).
    """
    # Work on a private copy so this test never aliases the shared defaults.
    custom_kwargs = dict(BASE_KWARGS)
    ep_world_size = custom_kwargs["ep_world_size"]
    custom_args = tuple(custom_kwargs.values())
    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)


def test_dispatch_gmm_combine_decode_with_mc2_mask():
    """Same comparison as the base test, but with ``with_mc2_mask`` enabled.

    Spawns one ``run_once`` process per EP rank, passing the overridden
    configuration positionally.
    """
    # Copy before overriding: assigning through an alias of BASE_KWARGS would
    # leave with_mc2_mask=True in the module-level defaults for every test
    # that runs afterwards. Overriding an existing key keeps its position, so
    # the positional order of the values is unchanged.
    custom_kwargs = dict(BASE_KWARGS)
    custom_kwargs["with_mc2_mask"] = True
    ep_world_size = custom_kwargs["ep_world_size"]
    custom_args = tuple(custom_kwargs.values())
    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)