[Feature] Add token mask for DispatchGmmCombineDecode operator (#5171)
### What this PR does / why we need it?
This PR adds an optional input `x_active_mask` to the DispatchGmmCombineDecode
operator. When the mask is provided, only the tokens whose mask entry is True
are dispatched and processed; the remaining tokens are skipped.
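For orientation, the operator's Python surface after this change looks as follows (a minimal sketch based on the test code in this diff; shapes and the distributed setup are illustrative only, not part of the patch):

```python
import torch

batch_size = 8
# x_active_mask is a 1-D bool tensor of length batch_size: tokens whose
# entry is True are dispatched and computed, the rest are skipped.
x_active_mask = torch.cat((torch.ones(5), torch.zeros(3))).bool()

# Keyword arguments as exercised by the test below (abridged):
# output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
#     x=x, expert_ids=expert_ids,
#     gmm1_permuted_weight=..., gmm1_permuted_weight_scale=...,
#     gmm2_weight=..., gmm2_weight_scale=...,
#     expert_scales=expert_scales,
#     expert_smooth_scales=smooth_scales,
#     x_active_mask=x_active_mask,  # new optional input; None keeps old behavior
#     group_ep=..., ep_rank_size=..., ep_rank_id=..., ...)
```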
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
@@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
     const aclTensor *gmm1PermutedWeightScale,
     const aclTensor *gmm2Weight,
     const aclTensor *gmm2WeightScale,
+    const aclTensor *expertScales,
     const aclTensor *expertSmoothScalesOptional,
-    const aclTensor *expertScalesOptional,
+    const aclTensor *xActiveMaskOptional,
     char *groupEp,
     int64_t epRankSize,
     int64_t epRankId,

@@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
     const aclTensor *gmm1PermutedWeightScale,
     const aclTensor *gmm2Weight,
     const aclTensor *gmm2WeightScale,
+    const aclTensor *expertScales,
     const aclTensor *expertSmoothScalesOptional,
-    const aclTensor *expertScalesOptional,
+    const aclTensor *xActiveMaskOptional,
     char *groupEp,
     int64_t epRankSize,
     int64_t epRankId,

@@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
     aclOpExecutor **executor)
 {
     return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
-        gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
+        gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
        epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
        output, epRecvCount, workspaceSize, executor);
 }
@@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
     const aclTensor *gmm1PermutedWeightScale,
     const aclTensor *gmm2Weight,
     const aclTensor *gmm2WeightScale,
+    const aclTensor *expertScales,
     const aclTensor *expertSmoothScalesOptional,
-    const aclTensor *expertScalesOptional,
+    const aclTensor *xActiveMaskOptional,
     char *groupEp,
     int64_t epRankSize,
     int64_t epRankId,

@@ -45,14 +45,19 @@ public:
             .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("expert_scales")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("expert_smooth_scales")
             .ParamType(OPTIONAL)
             .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
-        this->Input("expert_scales")
+        this->Input("x_active_mask")
             .ParamType(OPTIONAL)
-            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .DataType({ge::DT_BOOL, ge::DT_BOOL})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
         this->Output("output")
@@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
 constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
 constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
 constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
-constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
-constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
+constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
+constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
+constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;

 constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
 constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;

@@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
 constexpr uint32_t MAX_BATCH_SIZE = 256;
 constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
 constexpr uint32_t SUPPORT_TOP_K = 12;
+constexpr uint32_t ONE_DIMS = 1;
 constexpr uint32_t TWO_DIMS = 2;
 constexpr uint32_t MIN_TOKEN_LENGTH = 512;
 constexpr uint32_t MAX_TOKEN_LENGTH = 7168;

@@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
     uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
     uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
     uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
+    uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
     uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
     uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
@@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
     OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
         return ge::GRAPH_FAILED);

+    const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
+        INPUT_SHARE_X_ACTIVE_MASK_INDEX);
+    if (xActiveMaskStorageShape != nullptr) {
+        OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
+            OPS_LOG_E(nodeName, "xActiveMask shape dims must be 1, but current dim num is %lu.",
+                xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
+            return ge::GRAPH_FAILED);
+        const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
+        OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
+            "xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
+            return ge::GRAPH_FAILED);
+    }
     return ge::GRAPH_SUCCESS;
 }
@@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
     tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
     OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
         return ge::GRAPH_FAILED);
+    OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
+        nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
     OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
         OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
     SetHcommCfg(context, tilingData, groupEp);
-    if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
-        context->SetTilingKey(0);
-    } else {
-        context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
-    }
+    const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
+        INPUT_SHARE_X_ACTIVE_MASK_INDEX);
+    bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
+    uint64_t tilingKey = 0;
+    if (xActiveMaskEnable) {
+        tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
+    }
+    if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
+        tilingKey |= EXEC_FLAG_DEEP_FUSE;
+    }
+    context->SetTilingKey(tilingKey);
     context->SetBlockDim(aicNum);
     return ge::GRAPH_SUCCESS;
 }
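The tiling key is now a small bit field rather than the previous 0/1 enum: bit 0 is `EXEC_FLAG_DEEP_FUSE` (`1U << 0`) and bit 2 is `EXEC_FLAG_X_ACTIVE_MASK` (`1U << 2`, both defined near the end of this diff), so the reachable keys are 0, 1, 4 and 5, which is exactly what the kernel entry below matches. A minimal model of the composition (illustrative Python, not part of the patch):

```python
EXEC_FLAG_DEEP_FUSE = 1 << 0      # moeExpertNumPerRank != 1
EXEC_FLAG_X_ACTIVE_MASK = 1 << 2  # x_active_mask tensor was supplied

def tiling_key(deep_fuse: bool, has_mask: bool) -> int:
    key = 0
    if has_mask:
        key |= EXEC_FLAG_X_ACTIVE_MASK
    if deep_fuse:
        key |= EXEC_FLAG_DEEP_FUSE
    return key

# All reachable keys: exactly the TILING_KEY_IS values checked in the kernel.
assert sorted(tiling_key(d, m) for d in (False, True)
              for m in (False, True)) == [0, 1, 4, 5]
```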
@@ -14,7 +14,8 @@
 extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
     // input
     GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
+    GM_ADDR x_active_mask,
     // output
     GM_ADDR output, GM_ADDR outputRecvCount,
     // system

@@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
     REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
     KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
     GET_TILING_DATA(tiling_data, tiling);
-    if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
+    if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
         DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
         op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
-            expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
+            expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
         op.Process();
     }
 }
@@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
     layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
     GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
     GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
-    GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
+    GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
     uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
     uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
     uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)

@@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
     using GemmKernel = typename std::conditional<
         (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
         Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
-            XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
+            EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
         Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
             BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;

@@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
         gmexpertIds,
         gmExpandIdx,
         gmEpSendCount,
+        xActiveMask,
         gmResvered,
         gmOutputRecvCount,
         epRankSize,

@@ -241,7 +242,7 @@ public:
     __aicore__ inline void Init(
         // input
         GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-        GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+        GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
         // output
         GM_ADDR output, GM_ADDR outputRecvCount,
         // system

@@ -260,6 +261,7 @@ private:
     GM_ADDR workspaceGM_;
     GM_ADDR gmSmoothScales_;
     GM_ADDR gmexpertScales_;
+    GM_ADDR xActiveMask_;

     uint32_t maxTokenNum_{0};
     uint32_t gmm1OutputDim_{0};

@@ -291,7 +293,8 @@ template <TemplateMC2TypeClass>
 __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
     // input
     GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
-    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
+    GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
+    GM_ADDR x_active_mask,
     // output
     GM_ADDR output, GM_ADDR outputRecvCount,
     // system

@@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
     gmOutputRecvCount_ = outputRecvCount;
     workspaceGM_ = workspaceGM;
     gmexpertScales_ = expert_scales;
+    xActiveMask_ = x_active_mask;
     tilingData_ = tilingData;
     epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
     epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;

@@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
     GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
     workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);

-    if constexpr (EXEC_FLAG == 0) {
+    if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
         if constexpr (g_coreType == AscendC::AIV) {
             AscendC::TPipe tpipe;
-            MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
+            MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
                 dispatcher;
-            dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
+            dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
                 gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
             dispatcher.Process();
             tpipe.Destroy();

@@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
             Gmm1BlockScheduler>(
             gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
             gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
-            layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
+            layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
             gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
             sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
         AscendC::PipeBarrier<PIPE_ALL>();

@@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()

     MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
     if (g_coreType == AscendC::AIV) {
-        combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
+        combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
             workspaceGM_, nullptr, tilingData_);
     }
     GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
@@ -353,7 +353,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
     row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
 }

-template <typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
+template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
     class ElementGroupList_>
 class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
 {

@@ -411,6 +411,7 @@ public:
     GM_ADDR gmX;
     GM_ADDR debugGm;
     GM_ADDR gmexpertIds;
+    GM_ADDR gmXActiveMask;

     GM_ADDR gmExpandIdx;
     GM_ADDR gmEpSendCount;

@@ -438,7 +439,7 @@ public:
         LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_,
         LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_,
         GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_,
-        GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_,
+        GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_,
         GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_,
         uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_,
         uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_,

@@ -465,6 +466,7 @@ public:
         gmExpandIdx(gmExpandIdx_),
         gmEpSendCount(gmEpSendCount_),
         gmOutputRecvCount(gmOutputRecvCount_),
+        gmXActiveMask(gmXActiveMask_),
         gmResvered(gmResvered_),
         epRankSize(epRankSize_),
         epRankId(epRankId_),

@@ -635,6 +637,39 @@ public:
         AscendC::SyncAll<false>();
     }

+    CATLASS_DEVICE
+    void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset)
+    {
+        int64_t subUbOffset = ubOffset;
+        AscendC::LocalTensor<bool> maskInputTensor = (resource.ubBuf.template
+            GetBufferByByte<bool>(subUbOffset));
+        AscendC::LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.template ReinterpretCast<int8_t>();
+        subUbOffset += CEIL_UP(axisBS * sizeof(bool));
+        AscendC::LocalTensor<half> maskTmpTensor = (resource.ubBuf.template
+            GetBufferByByte<half>(subUbOffset));
+        subUbOffset += CEIL_UP(axisBS * sizeof(half));
+        AscendC::LocalTensor<half> sumOutTensor = (resource.ubBuf.template
+            GetBufferByByte<half>(subUbOffset));
+        subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
+
+        AscendC::GlobalTensor<bool> xActiveMaskGMTensor;
+        xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask);
+        uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool));
+
+        AscendC::DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS * sizeof(bool)), 0U, 0U, 0U};
+        AscendC::DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
+        AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams);
+        AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(0);
+        AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(0);
+        AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS);
+        AscendC::PipeBarrier<PIPE_V>();
+        AscendC::SumParams params{1, axisBsAlignSize, axisBS};
+        AscendC::Sum(sumOutTensor, maskTmpTensor, params);
+        AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
+        AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
+        activeMaskBsCnt = static_cast<int32_t>(sumOutTensor.GetValue(0));
+    }
+
     CATLASS_DEVICE
     void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset)
     {
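`TokenActiveMaskCal` reduces the whole mask to a single scalar, the number of active tokens: the bool mask is copied into UB, reinterpreted as int8 (there is no direct bool-to-half cast), widened to half, and summed. A host-side equivalent of the arithmetic (illustrative Python, not part of the patch):

```python
import torch

def token_active_mask_cal(x_active_mask: torch.Tensor) -> int:
    # Mirrors ReinterpretCast<int8_t> + Cast(..., CAST_NONE) + Sum in the
    # kernel: True bytes count as 1, False bytes as 0.
    return int(x_active_mask.to(torch.int8).to(torch.float16).sum())

assert token_active_mask_cal(torch.tensor([True, True, False, True])) == 3
```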
@@ -778,8 +813,8 @@ public:
     void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale)
     {
         uint32_t newAivId = sendCoreIdx - sendToMoeAivNum;
-        uint32_t sendTokenNum = axisBS / sendToShareAivNum;
-        uint32_t remainderTokenNum = axisBS % sendToShareAivNum;
+        uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum;
+        uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum;
         uint32_t startTokenId = sendTokenNum * newAivId;
         if (newAivId < remainderTokenNum) {
             sendTokenNum += 1;

@@ -788,7 +823,7 @@ public:
             startTokenId += remainderTokenNum;
         }
         uint32_t endTokenId = startTokenId + sendTokenNum;
-        if (startTokenId >= axisBS) {
+        if (startTokenId >= activeMaskBsCnt) {
             return;
         }
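From here on, the send cores split `activeMaskBsCnt` tokens instead of `axisBS`, keeping the existing quotient/remainder scheme: with `n` cores, each gets `cnt / n` tokens and the first `cnt % n` cores take one extra; a core whose start index lands past the active count simply returns, and the combine path later in this diff adds an explicit early return when the count is zero. A sketch of that index math (illustrative Python, not part of the patch):

```python
def token_range(core_id: int, num_cores: int, active_cnt: int):
    """[start, end) token range a core handles, or None if it has no work."""
    per_core, remainder = divmod(active_cnt, num_cores)
    count = per_core + (1 if core_id < remainder else 0)
    start = per_core * core_id + min(core_id, remainder)
    if start >= active_cnt:  # mirrors: if (startTokenId >= activeMaskBsCnt) return;
        return None
    return start, start + count

# 10 active tokens over 4 cores -> chunks of sizes 3, 3, 2, 2
assert [token_range(i, 4, 10) for i in range(4)] == [
    (0, 3), (3, 6), (6, 8), (8, 10)]
```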
@@ -962,10 +997,13 @@ public:
     }

     CATLASS_DEVICE void
-    SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx)
+    SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask)
     {
         ubOffset = 0;
-        expertIdsCnt = axisBS * axisK;
+        if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+            TokenActiveMaskCal(gmXActiveMask, ubOffset);
+        }
+        expertIdsCnt = activeMaskBsCnt * axisK;

         AscendC::GlobalTensor<int32_t> expertIdsGMTensor_;
         expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds);

@@ -1372,6 +1410,7 @@ public:
         hCommuSize = hOutSize + scaleParamPad;
         axisHCommu = hCommuSize / sizeof(int8_t);
         axisBS = params.bs;
+        activeMaskBsCnt = axisBS;
         axisK = params.topK;
         uint32_t maxAxisBs = params.globalBs / epRankSize;

@@ -1489,7 +1528,7 @@ public:
         AivInitState();
         if (isSendCore) {
             SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
-                (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx);
+                (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask);
         }
         if (isRecvCore) {
             RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,

@@ -1596,6 +1635,7 @@ private:
     uint32_t hCommuSize{0};
     uint32_t axisHCommu{0};
     uint32_t axisBS{0};
+    uint32_t activeMaskBsCnt{0};
     uint32_t axisK{0};
     uint32_t totalTokenCount{0};
     uint32_t expertIdsCnt{0};
@@ -62,7 +62,7 @@ class CamMoeDistributeCombine
 public:
     __aicore__ inline CamMoeDistributeCombine(){};
     __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount,
-        GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
+        GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
         const DispatchGmmCombineDecodeTilingData *tilingData);
     __aicore__ inline void Process();
     __aicore__ inline void AllToAllSend();

@@ -145,6 +145,7 @@ private:
     GlobalTensor<ExpandIdxType> epSendCountGM_;
     GlobalTensor<ExpandIdxType> tpSendCountGM_;
     GlobalTensor<float> expandScalesGM_;
+    GlobalTensor<bool> xActiveMaskGM_;
     GlobalTensor<ExpandXType> expandOutGlobal_;
     GlobalTensor<ExpandXType> rankWindow_;
     GlobalTensor<int32_t> rankStates_;

@@ -169,6 +170,8 @@ private:
     CombineCalcInfo calcInfo_;
     uint32_t axisBS_{0};
     uint32_t axisMaxBs_{0};
+    uint32_t axisBsAlignSize_{0};
+    uint64_t activeMaskBsCnt_{0};
     uint32_t axisH_{0};
     uint32_t axisK_{0};
     uint32_t aivNum_{0};

@@ -224,6 +227,9 @@ private:
     TBuf<> gatherMaskOutBuf_; // gather mask output buf
     TBuf<> gatherTmpBuf_;
     TBuf<> statusSumOutBuf_;
+    TBuf<> xActMaskTBuf_;
+    TBuf<> xActMaskCastTBuf_;
+    TBuf<> xActMaskSumTBuf_;
     float sumTarget_{0.0};
     int32_t epStateValue_;
     bool isShardExpert_{false};

@@ -231,7 +237,7 @@ private:

 template <TemplateMC2TypeClass>
 __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
-    GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales,
+    GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR xActiveMask,
     GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
 {
     tpipe_ = pipe;

@@ -264,8 +270,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
     expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx);
     epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount);
     expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales);
+    xActiveMaskGM_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
     expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut);
     axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
+    activeMaskBsCnt_ = axisBS_;
     axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h;
     axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
     if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
@@ -395,6 +403,27 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuf
     tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
     tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));
     tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float));
+
+    if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+        axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
+        tpipe_->InitBuffer(xActMaskTBuf_, axisBsAlignSize_);
+        tpipe_->InitBuffer(xActMaskCastTBuf_, axisBsAlignSize_ * sizeof(half));
+        tpipe_->InitBuffer(xActMaskSumTBuf_, axisBsAlignSize_ * sizeof(half));
+        LocalTensor<bool> xActiveMaskTensor = xActMaskTBuf_.Get<bool>();
+        LocalTensor<half> tempTensor = xActMaskCastTBuf_.Get<half>();
+        LocalTensor<half> sumOutTensor = xActMaskSumTBuf_.Get<half>();
+        DataCopyExtParams xActiveMaskParams{1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
+        DataCopyPadExtParams<bool> xActiveMaskCopyPadParams{false, 0U, 0U, 0U};
+        DataCopyPad(xActiveMaskTensor, xActiveMaskGM_, xActiveMaskParams, xActiveMaskCopyPadParams);
+        SyncFunc<AscendC::HardEvent::MTE2_V>();
+        LocalTensor<int8_t> xActiveMaskInt8Tensor = xActiveMaskTensor.ReinterpretCast<int8_t>();
+        Cast(tempTensor, xActiveMaskInt8Tensor, RoundMode::CAST_NONE, axisBS_);
+        PipeBarrier<PIPE_V>();
+        SumParams params{1, axisBsAlignSize_, axisBS_};
+        Sum(sumOutTensor, tempTensor, params);
+        SyncFunc<AscendC::HardEvent::V_S>();
+        activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
+    }
 }

 template <TemplateMC2TypeClass>
@@ -645,13 +674,16 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::WaitDispatc
 template <TemplateMC2TypeClass>
 __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindowCopy()
 {
+    if (activeMaskBsCnt_ == 0U) {
+        return;
+    }
     uint32_t beginIndex = 0;
     uint32_t endIndex = 0;
     uint32_t processLen = 0;
     uint32_t tokenOffset = 0;
-    if (axisBS_ < aivNum_) {
-        uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_
-        if (coreIdx_ >= (axisBS_ * aivNumPerToken)) {
+    if (activeMaskBsCnt_ < aivNum_) {
+        uint32_t aivNumPerToken = aivNum_ / activeMaskBsCnt_; // activeMaskBsCnt_ < aivNum_
+        if (coreIdx_ >= (activeMaskBsCnt_ * aivNumPerToken)) {
             return;
         }
         uint32_t tokenIndex = coreIdx_ / aivNumPerToken;
@@ -663,8 +695,8 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
         beginIndex = tokenIndex;
         endIndex = beginIndex + 1U;
     } else {
-        uint32_t tokenPerAivNum = axisBS_ / aivNum_;
-        uint32_t remainderToken = axisBS_ % aivNum_;
+        uint32_t tokenPerAivNum = activeMaskBsCnt_ / aivNum_;
+        uint32_t remainderToken = activeMaskBsCnt_ % aivNum_;
         beginIndex = tokenPerAivNum * coreIdx_;
         if (coreIdx_ < remainderToken) {
             tokenPerAivNum++;

@@ -723,10 +755,10 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::LocalWindow
         }
         LocalTensor<ExpandXType> rowTmpLocal = tokenBuf_.Get<ExpandXType>();
         if (sharedExpertRankNum_ > 0U) {
-            uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
-            uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
-            uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
-                epRankId_ * axisBS_ / sharedExpertRankNum_;
+            uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
+            uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
+            uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
+                epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
             __gm__ ExpandXType *shareAddr =
                 (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) +
                 (tokenIndex - preCnt) * axisH_ + tokenOffset;
@@ -22,6 +22,7 @@ constexpr uint8_t BUFFER_NUM = 2;
 constexpr uint32_t STATE_OFFSET = 512; // state space offset
 constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M
 constexpr uint32_t UB_ALIGN = 32;
+constexpr uint64_t ALIGNED_LEN_256 = 256UL;
 constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
 constexpr uint8_t COMM_NUM = 2;
 constexpr uint8_t COMM_EP_IDX = 0;

@@ -47,18 +48,13 @@ __aicore__ inline void SyncFunc()
     AscendC::WaitFlag<event>(eventID);
 }

-#define TemplateDispatchTypeClass \
-    typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
-    bool IsNeedAllgater
-#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater
-
 using namespace AscendC;
 template <TemplateDispatchTypeClass>
 class CamMoeDistributeDispatch
 {
 public:
     __aicore__ inline CamMoeDistributeDispatch(){};
-    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut,
+    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut,
         GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
         GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
         GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData);
@@ -80,6 +76,7 @@ private:
     __aicore__ inline void AllGatherSetStatusAndWait();
     __aicore__ inline void ResetStatus();
     __aicore__ inline void QuantInit(GM_ADDR scales);
+    __aicore__ inline void TokenActiveMaskCal();
     __aicore__ inline void AllgatherProcessOut();
     __aicore__ inline void UpdataMultiMoeTokenNumsOut();
     __aicore__ inline void UpdataTokenNumsOut();

@@ -112,6 +109,7 @@ private:
     GlobalTensor<XType> xGMTensor_;
     GlobalTensor<int32_t> expertIdsGMTensor_;
     GlobalTensor<float> scalesGMTensor_;
+    GlobalTensor<bool> xActiveMaskGMTensor_;
     GlobalTensor<ExpandXOutType> expandXOutGMTensor_;
     GlobalTensor<float> dynamicScalesOutGMTensor_;
     GlobalTensor<int64_t> expertTokenNumsOutGMTensor_;

@@ -144,6 +142,9 @@ private:
     TBuf<> rowMaxBuf_;
     TBuf<> receiveDataCastFloatBuf_;
     TBuf<> smoothScalesBuf_;
+    TBuf<> dstExpBuf_;
+    TBuf<> subExpBuf_;
+    TBuf<> gatherMaskTBuf_;
     TQueBind<QuePosition::VECIN, QuePosition::VECOUT, 1> xQueue_;
     TQue<QuePosition::VECIN, 1> xInQueue_;
     TQue<QuePosition::VECOUT, 1> xOutQueue_;

@@ -161,7 +162,10 @@ private:
     GM_ADDR tpLocalStatusWindowGM_;
     GlobalTensor<GM_ADDR> peerMemsAddrGm_;
     uint32_t axisBS_{0};
+    uint32_t axisBsAlignSize_{0};
+    uint64_t activeMaskBsCnt_{0};
     uint32_t axisMaxBS_{0};
+    uint64_t sendToMoeExpTokenCnt_{0};
     uint32_t axisH_{0};
     uint32_t axisK_{0};
     uint32_t aivNum_{0};

@@ -183,6 +187,7 @@ private:
     uint32_t hSize_{0};
     uint32_t hOutSize_{0};
     uint32_t hCommuSize_{0};
+    uint32_t maxSize_{0};
     uint32_t scaleParamPad_{0};
     uint32_t axisHCommu_{0};
     uint32_t startExpertId_;
@@ -216,7 +221,7 @@ private:

 template <TemplateDispatchTypeClass>
 __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
-    GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
+    GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut,
     GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
     GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData)
 {

@@ -268,6 +273,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
     tpWorldSize_ = 1;
     xGMTensor_.SetGlobalBuffer((__gm__ XType *)x);
     expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds);
+    xActiveMaskGMTensor_.SetGlobalBuffer((__gm__ bool*)xActiveMask);
     expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut);
     dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut);
     expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut);

@@ -329,7 +335,8 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
     if (isQuant_) {
         QuantInit(scales);
     }
-    uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
+    uint32_t expertIdsCnt = axisBS_ * axisK_;
+    uint32_t expertIdsSize = Ceil(expertIdsCnt * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;
     tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);
     expertIdsTensor_ = expertIdsBuf_.Get<int32_t>();
     tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);
@@ -339,7 +346,22 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Init(
     tpipe_->InitBuffer(getTotalBuf_,
         epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));
     tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);

+    uint32_t hFp32Size = axisH_ * sizeof(float);
+    uint32_t xActivateMaskSize = axisBS_ * (Ceil(axisK_ * sizeof(bool), UB_ALIGN) * UB_ALIGN) * sizeof(half);
+    uint32_t bsAlign256 = Ceil(axisBS_ * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
+    uint32_t bsKAlign256 = Ceil(expertIdsCnt * sizeof(half), ALIGNED_LEN_256) * ALIGNED_LEN_256 / sizeof(half);
+    expertIdsSize = Ceil(expertIdsSize, UB_ALIGN) * UB_ALIGN;
+    maxSize_ = hFp32Size > expertIdsSize ? hFp32Size : expertIdsSize;
+    maxSize_ = maxSize_ > xActivateMaskSize ? maxSize_ : xActivateMaskSize;
+    maxSize_ = maxSize_ > bsKAlign256 ? maxSize_ : bsKAlign256;
+    tpipe_->InitBuffer(gatherMaskTBuf_, maxSize_);
+    if constexpr (DynamicQuant || StaticQuant) {
+        dstExpBuf_ = receiveDataCastFloatBuf_; // memory reuse
+        subExpBuf_ = smoothScalesBuf_; // memory reuse
+    } else {
+        tpipe_->InitBuffer(dstExpBuf_, maxSize_); // BS * K * 4 = 32K
+        tpipe_->InitBuffer(subExpBuf_, maxSize_); // BS * K * 4 = 32K
+    }
     moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
     if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
         axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
@@ -370,11 +392,36 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::Quant
     dynamicScalesTensor_ = dynamicScalesBuf_.Get<float>();
 }

+template <TemplateDispatchTypeClass>
+__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::TokenActiveMaskCal()
+{
+    // Copy in x_active_mask; for now it is only used to compute the total number of valid tokens.
+    LocalTensor<half> maskTmpTensor;
+    LocalTensor<half> sumOutTensor;
+    LocalTensor<bool> maskInputTensor;
+    axisBsAlignSize_ = Ceil(axisBS_ * sizeof(bool), UB_ALIGN) * UB_ALIGN;
+    maskInputTensor = dstExpBuf_.Get<bool>();
+    maskTmpTensor = subExpBuf_.Get<half>();
+    sumOutTensor = gatherMaskTBuf_.Get<half>();
+    DataCopyExtParams maskParams = {1U, static_cast<uint32_t>(axisBS_ * sizeof(bool)), 0U, 0U, 0U};
+    DataCopyPadExtParams<bool> maskCopyPadParams{false, 0U, 0U, 0U};
+    DataCopyPad(maskInputTensor, xActiveMaskGMTensor_, maskParams, maskCopyPadParams);
+    SyncFunc<AscendC::HardEvent::MTE2_V>();
+    LocalTensor<int8_t> maskInputInt8Tensor = maskInputTensor.ReinterpretCast<int8_t>();
+    Cast(maskTmpTensor, maskInputInt8Tensor, RoundMode::CAST_NONE, axisBS_);
+    PipeBarrier<PIPE_V>();
+    SumParams params{1, axisBsAlignSize_, axisBS_};
+    Sum(sumOutTensor, maskTmpTensor, params);
+    SyncFunc<AscendC::HardEvent::V_S>();
+    activeMaskBsCnt_ = static_cast<int32_t>(sumOutTensor.GetValue(0));
+    sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
+}
+
 template <TemplateDispatchTypeClass>
 __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
 {
-    uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_;
-    uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_;
+    uint32_t sendTokenNum = activeMaskBsCnt_ / sharedUsedAivNum_;
+    uint32_t remainderTokenNum = activeMaskBsCnt_ % sharedUsedAivNum_;
     uint32_t newAivId = aivId_ - moeUsedAivNum_;
     uint32_t startTokenId = sendTokenNum * newAivId;
     if (newAivId < remainderTokenNum) {
@@ -383,16 +430,16 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
     } else {
         startTokenId += remainderTokenNum;
     }
-    if (startTokenId >= axisBS_) {
+    if (startTokenId >= activeMaskBsCnt_) {
         return;
     }
     uint32_t endTokenId = startTokenId + sendTokenNum;
     for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) {
         uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum);
-        uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_;
-        uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_;
-        uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ -
-            epRankId_ * axisBS_ / sharedExpertRankNum_;
+        uint32_t temp = (epRankId_ * activeMaskBsCnt_) / sharedExpertRankNum_;
+        uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, activeMaskBsCnt_) - 1 - epRankId_;
+        uint32_t preCnt = (moeOnShareRank + epRankId_) * activeMaskBsCnt_ / sharedExpertRankNum_ -
+            epRankId_ * activeMaskBsCnt_ / sharedExpertRankNum_;
         GlobalTensor<ExpandXOutType> dstWinGMTensor;
         dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) +
             expertPerSizeOnWin_ * epRankId_));
@@ -440,7 +487,7 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
 template <TemplateDispatchTypeClass>
 __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToMoeExpert()
 {
-    uint32_t expertIdsCnt = axisBS_ * axisK_;
+    uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
     uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_;
     uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_;
     uint32_t startTokenId = sendTokenNum * aivId_;

@@ -496,7 +543,12 @@ __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendT
 template <TemplateDispatchTypeClass>
 __aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::AlltoAllDispatch()
 {
-    uint32_t expertIdsCnt = axisBS_ * axisK_;
+    activeMaskBsCnt_ = axisBS_;
+    sendToMoeExpTokenCnt_ = activeMaskBsCnt_ * axisK_;
+    if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) {
+        TokenActiveMaskCal();
+    }
+    uint32_t expertIdsCnt = activeMaskBsCnt_ * axisK_;
     DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U};
     DataCopyPadExtParams<int32_t> copyPadParams{false, 0U, 0U, 0U};
     DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams);
@@ -14,5 +14,8 @@

 #define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
 #define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
+
+#define TemplateDispatchTypeClass \
+    typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
+    bool IsNeedAllgater, uint32_t EXEC_FLAG
+#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG
 #endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H

@@ -70,5 +70,6 @@ constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0;
 constexpr uint32_t WORKSPACE_STAGES = 4;

 constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0);
+constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2);

 #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H
@@ -640,8 +640,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
     const at::Tensor &gmm1_permuted_weight_scale,
     const at::Tensor &gmm2_weight,
     const at::Tensor &gmm2_weight_scale,
+    const at::Tensor &expert_scales,
     const c10::optional<at::Tensor> &expert_smooth_scales,
-    const c10::optional<at::Tensor> &expert_scales,
+    const c10::optional<at::Tensor> &x_active_mask,
     c10::string_view group_ep,
     int64_t ep_rank_size,
     int64_t ep_rank_id,

@@ -674,8 +675,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
         gmm1_permuted_weight_scale,
         gmm2_weight,
         gmm2_weight_scale,
-        expert_smooth_scales,
         expert_scales,
+        expert_smooth_scales,
+        x_active_mask,
         //input attrs
         group_ep_ptr,
         ep_rank_size,

@@ -1188,7 +1190,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
         " Tensor gmm1_permuted_weight_scale,"
         " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
-        " Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
+        " Tensor expert_scales, Tensor? expert_smooth_scales=None,"
+        " Tensor? x_active_mask=None,"
         " str group_ep='',"
         " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
         " int shared_expert_num=1, int shared_expert_rank_num=0,"
@@ -161,8 +161,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
     const at::Tensor &gmm1_permuted_weight_scale,
     const at::Tensor &gmm2_weight,
     const at::Tensor &gmm2_weight_scale,
+    const at::Tensor &expert_scales,
     const c10::optional<at::Tensor> &expert_smooth_scales,
-    const c10::optional<at::Tensor> &expert_scales,
+    const c10::optional<at::Tensor> &x_active_mask,
     c10::string_view group_ep,
     int64_t ep_rank_size,
     int64_t ep_rank_id,
@@ -16,6 +16,19 @@ torch.manual_seed(42)
 torch_npu.npu.config.allow_internal_format = True
 enable_custom_op()
 LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
+BASE_KWARGS = {
+    "batch_size": 64,
+    "token_hidden_size": 7168,
+    "moe_intermediate_size": 2048,
+    "ep_world_size": 16,
+    "moe_expert_num": 64,
+    "shared_expert_rank_num": 0,
+    "top_k": 8,
+    "test_bfloat16": True,
+    "enable_dynamic_bs": False,
+    "test_graph": False,
+    "with_mc2_mask": False
+}


 def redirect_output(log_file_path):
@@ -115,11 +128,14 @@ class DecodeMoeOps(torch.nn.Module):
         self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
             gmm2_weight_scale.float(), requires_grad=False)

-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         raise NotImplementedError("To be implemented in subclass")

-    def forward(self, x, expert_ids, smooth_scales, expert_scales):
-        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
+    def forward(self, x, expert_ids, smooth_scales, expert_scales,
+                x_active_mask):
+        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales,
+                               x_active_mask)


 class SmallOps(DecodeMoeOps):

@@ -144,11 +160,13 @@ class SmallOps(DecodeMoeOps):
                          shared_expert_rank_num)
         self.tp_hcomm_info = ""

-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         outputs = torch_npu.npu_moe_distribute_dispatch_v2(
             x=x,
             expert_ids=expert_ids,
             expert_scales=expert_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_world_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,
@@ -200,6 +218,7 @@ class SmallOps(DecodeMoeOps):
             assist_info_for_combine=assist_info_for_combine,
             ep_send_counts=ep_send_counts,
             expert_scales=expert_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_world_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,

@@ -237,7 +256,8 @@ class FusionOp(DecodeMoeOps):
                          ep_world_size, moe_expert_num, global_rank_id,
                          shared_expert_rank_num)

-    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales,
+                   x_active_mask):
         output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
             x=x,
             expert_ids=expert_ids,

@@ -245,8 +265,9 @@ class FusionOp(DecodeMoeOps):
             gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
             gmm2_weight=self.gmm2_weight,
             gmm2_weight_scale=self.gmm2_weight_scale_fp32,
-            expert_smooth_scales=smooth_scales,
             expert_scales=expert_scales,
+            expert_smooth_scales=smooth_scales,
+            x_active_mask=x_active_mask,
             group_ep=self.ep_hcomm_info,
             ep_rank_size=self.ep_world_size,
             ep_rank_id=self.global_rank_id,
@@ -267,12 +288,13 @@ def generate_datas(batch_size,
                    shared_expert_rank_num=0,
                    top_k=8,
                    test_bfloat16=True,
-                   enable_dynamic_bs=False):
+                   enable_dynamic_bs=False,
+                   with_mc2_mask=False):
     is_shared_expert = global_rank_id < shared_expert_rank_num
     moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
                                                  shared_expert_rank_num)
     actual_bs = int(
-        torch.randint(1, batch_size, [1]).item(
+        torch.randint(2 if with_mc2_mask else 1, batch_size, [1]).item(
         ) if enable_dynamic_bs else batch_size)
     local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
     gmm1_input_dim = token_hidden_size
@@ -317,9 +339,16 @@ def generate_datas(batch_size,
     else:
         x = x.half()
         smooth_sales = None
-    return (x, expert_ids, smooth_sales, expert_scales), \
+    x_active_mask = None
+    valid_token_num = actual_bs
+    if with_mc2_mask:
+        valid_token_num = int(torch.randint(1, actual_bs, [1]).item())
+        x_active_mask = torch.cat(
+            (torch.ones(valid_token_num),
+             torch.zeros(actual_bs - valid_token_num))).bool()
+    return (x, expert_ids, smooth_sales, expert_scales, x_active_mask), \
         (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
-        actual_bs
+        actual_bs, valid_token_num

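Note that the generated mask is always a contiguous prefix of ones followed by zeros, and the assertion in `run_once` below only compares the first `valid_token_num` output rows, so the output rows of masked-off tokens are deliberately left unchecked; non-prefix masks are not exercised by this test. The expected layout (illustrative Python, not part of the patch):

```python
import torch

def make_prefix_mask(actual_bs: int, valid_token_num: int) -> torch.Tensor:
    # Matches generate_datas: active tokens first, inactive tokens last.
    return torch.cat((torch.ones(valid_token_num),
                      torch.zeros(actual_bs - valid_token_num))).bool()

assert make_prefix_mask(8, 5).tolist() == [True] * 5 + [False] * 3
```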
@@ -332,7 +361,8 @@ def run_once(local_rank_id,
              top_k=8,
              test_bfloat16=True,
              enable_dynamic_bs=False,
-             test_graph=False):
+             test_graph=False,
+             with_mc2_mask=False):
     log_file = redirect_output(f"local_rank_{local_rank_id}.log"
                                ) if output_to_file(local_rank_id) else None
     global_rank_id = local_rank_id  # single node

@@ -358,8 +388,8 @@ def run_once(local_rank_id,
     parameter = (batch_size, token_hidden_size, moe_intermediate_size,
                  ep_world_size, moe_expert_num, global_rank_id,
                  shared_expert_rank_num)
-    input_datas, weight_datas, actual_bs = generate_datas(
-        *parameter, top_k, test_bfloat16, enable_dynamic_bs)
+    input_datas, weight_datas, actual_bs, valid_token_num = generate_datas(
+        *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask)
     input_datas = [
         data.npu() if data is not None else None for data in input_datas
     ]

@@ -382,8 +412,8 @@ def run_once(local_rank_id,
     if log_file is not None:
         log_file.close()
     small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
-    torch.testing.assert_close(small_op_token_output.cpu(),
-                               fused_op_token_output.cpu(),
+    torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(),
+                               fused_op_token_output[0:valid_token_num].cpu(),
                                atol=2.0,
                                rtol=0.02)
     torch.testing.assert_close(small_op_count_output.cpu(),
@@ -394,18 +424,16 @@ def run_once(local_rank_id,


 @torch.inference_mode()
-def test():
-    batch_size = 64
-    token_hidden_size = 7168
-    moe_intermediate_size = 2048
-    ep_world_size = 16
-    moe_expert_num = 64
-    shared_expert_rank_num = 0
-    top_k = 8
-    test_bfloat16 = True
-    enable_dynamic_bs = False
-    test_graph = False
-    args = (batch_size, token_hidden_size, moe_intermediate_size,
-            ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
-            test_bfloat16, enable_dynamic_bs, test_graph)
-    mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)
+def test_dispatch_gmm_combine_decode_base():
+    custom_kwargs = BASE_KWARGS
+    ep_world_size = custom_kwargs["ep_world_size"]
+    custom_args = tuple(custom_kwargs.values())
+    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
+
+
+def test_dispatch_gmm_combine_decode_with_mc2_mask():
+    custom_kwargs = BASE_KWARGS
+    custom_kwargs["with_mc2_mask"] = True
+    ep_world_size = custom_kwargs["ep_world_size"]
+    custom_args = tuple(custom_kwargs.values())
+    mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True)
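Two details of these test entry points are worth flagging: `tuple(custom_kwargs.values())` relies on the dict's insertion order matching `run_once`'s positional parameter order, and `custom_kwargs = BASE_KWARGS` binds the shared dict rather than copying it, so the mc2-mask test mutates `BASE_KWARGS` for any test that runs afterwards in the same process. A defensive variant (illustrative Python, not part of the patch):

```python
custom_kwargs = dict(BASE_KWARGS)  # copy, so the shared defaults stay intact
custom_kwargs["with_mc2_mask"] = True
```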