diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp
index ff037903..3799b098 100644
--- a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp
+++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor*
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
             bool transB, bool weightNz,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor);
 
 extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream);
@@ -59,7 +59,7 @@ aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const ac
             const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor)
 {
     bool transB = false;
@@ -67,7 +67,7 @@ aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const ac
     aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2,
             probs, group, maxOutputSize, transB, weightNz,
-            out, workspaceSize, executor);
+            out, expertTokenNums, workspaceSize, executor);
     return ret;
 }
diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h
index 871a30b0..556e501a 100644
--- a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h
+++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h
@@ -43,7 +43,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWor
             const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor);
 
 /**
diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp
index 6b8a7a33..a393fed6 100644
--- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp
+++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp
@@ -62,6 +62,11 @@ class DispatchFFNCombine : public OpDef {
             .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
+        this->Output("expert_token_nums")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
 
         this->Attr("group").AttrType(REQUIRED).String();
         this->Attr("M").AttrType(OPTIONAL).Int();
diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
index db3cf771..476f43e5 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
@@ -20,32 +20,32 @@ using namespace AscendC;
 using namespace DispatchFFNCombineImpl;
 
 extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
-                                                            GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+                                                            GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     if (TILING_KEY_IS(1000000)) {
         KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000001)) {
         KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000010)) {
         KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000011)) {
         KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     }
 }
\ No newline at end of file
diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
index 704809dc..31c0471f 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h
@@ -55,7 +55,7 @@ class DispatchFFNCombine {
 public:
     __aicore__ inline DispatchFFNCombine() {};
     __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
+                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
 
     __aicore__ inline void Process();
 
@@ -68,6 +68,7 @@ private:
     GM_ADDR scale2GM_;
     GM_ADDR probs_;
     GM_ADDR outGM_;
+    GM_ADDR gmExpertTokenNums_;
     GM_ADDR workspaceGM_;
 
     GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
@@ -112,7 +113,7 @@ private:
 template
 __aicore__ inline void DispatchFFNCombine::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-                                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+                                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
@@ -127,6 +128,7 @@ __aicore__ inline void DispatchFFNCombine::Init(GM_ADDR xGM,
     probs_ = probs;
     outGM_ = outGM;
+    gmExpertTokenNums_ = expertTokenNums;
     workspaceGM_ = workspaceGM;
 
@@ -268,7 +270,7 @@ __aicore__ inline void DispatchFFNCombine::Process()
         outGM_, layoutD1, layoutD2, expertIdGM_,
         moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset, expertTokensBeforeCapacity, probs_,
-        workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
+        workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
 
     //Call kernel
     MatmulKernel kernel(params);
     kernel(params);
diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
index 462004d2..e0956c92 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
@@ -96,6 +96,7 @@ public:
     LayoutD1 layoutD1;
     LayoutD2 layoutD2;
     GM_ADDR ptrWorkspace;
+    GM_ADDR ptrExpertTokenNums;
     int32_t EP;
     int32_t listLen;
     int32_t expertPerRank;
@@ -139,7 +140,7 @@ public:
         GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_, GM_ADDR moeInitRoutingQuantV2Offset_,
         GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
-        GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
+        GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
         optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
     ) : problemShape(problemShape_),
         EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
@@ -155,7 +156,7 @@ public:
         expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
         moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
-        ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),
+        ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),
         moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
     {
     }
@@ -228,7 +229,7 @@ private:
         tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
-        tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
+        tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
     }
 
     template
@@ -335,15 +336,6 @@ private:
         AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
         syncgmmIdx++;
 
-        constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
-        __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
-        __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
-
-        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-            weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr(loopIdx, params.ptrB1));
-            scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale1));
-        }
         AscendC::PipeBarrier();
 
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -356,8 +348,8 @@ private:
             AscendC::GlobalTensor gmB1;
             AscendC::GlobalTensor gmS;
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-            gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
-            gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
+            gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1)));
+            gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1)));
 
             AscendC::PipeBarrier();
 
@@ -455,14 +447,6 @@ private:
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }
 
-        constexpr uint32_t MAX_EXPERTS_PER_RANK = 8;
-        __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
-        __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
-        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-            weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(loopIdx, params.ptrB2));
-            scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale2));
-        }
         AscendC::PipeBarrier();
 
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -476,8 +460,8 @@ private:
             AscendC::GlobalTensor gmB2;
             AscendC::GlobalTensor gmS2;
             AscendC::PipeBarrier();
             int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-            gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
-            gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
+            gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB2)));
+            gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale2)));
             if (currentM <= L1TileShape::M) {
                 gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
@@ -596,7 +580,6 @@ private:
             AscendC::SetFlag(EVENT_ID0);
             AscendC::WaitFlag(EVENT_ID0);
         }
-
         for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
             if (dstEpIdx == params.rank) {
                 continue;
             }
@@ -639,6 +622,13 @@ private:
             GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
         }
         AscendC::SyncAll();
+
+        AscendC::GlobalTensor ExpertTokenNums;
+        ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
+        AscendC::GlobalTensor LcalCumsumMM;
+        LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
+        CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
+        AscendC::SyncAll();
         uint16_t syncgmm1Idx = 0;
         AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
         syncgmm1Idx++;
diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp
index e699e6e6..d717c4b0 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(const aclTen
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
             bool transB, bool weightNz,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor);
 
 extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream);
@@ -59,7 +59,7 @@ aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, cons
             const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor)
 {
     bool transB = false;
@@ -67,7 +67,7 @@ aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, cons
     aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2,
             probs, group, maxOutputSize, transB, weightNz,
-            out, workspaceSize, executor);
+            out, expertTokenNums, workspaceSize, executor);
     return ret;
 }
diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h
index a14f61fb..355d905f 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h
+++ b/csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h
@@ -25,7 +25,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16Ge
             const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
             const aclTensor* probs, const char* group, int64_t maxOutputSize,
-            const aclTensor* out,
+            const aclTensor* out, const aclTensor* expertTokenNums,
             uint64_t* workspaceSize, aclOpExecutor** executor);
 
diff --git a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp
index 00bf2320..036823e9 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp
@@ -62,6 +62,11 @@ class DispatchFFNCombineBF16 : public OpDef {
             .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
             .Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("expert_token_nums")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
 
         this->Attr("group").AttrType(REQUIRED).String();
         this->Attr("M").AttrType(OPTIONAL).Int();
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp
index 370ffd82..8cd46a60 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.cpp
@@ -20,32 +20,32 @@ using namespace AscendC;
 using namespace DispatchFFNCombineBF16Impl;
 
 extern "C" __global__ __aicore__ void dispatch_ffn_combine_bf16(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
-                                                                 GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+                                                                 GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
     if (TILING_KEY_IS(1000000)) {
         KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
         DispatchFFNCombineBF16 op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000001)) {
         KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
         DispatchFFNCombineBF16 op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000010)) {
         KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
         DispatchFFNCombineBF16 op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000011)) {
         KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
         DispatchFFNCombineBF16 op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     }
 }
\ No newline at end of file
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h
index 12bd7949..14945eff 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16.h
@@ -57,7 +57,7 @@ class DispatchFFNCombineBF16 {
 public:
     __aicore__ inline DispatchFFNCombineBF16() {};
     __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
+                                GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
 
     __aicore__ inline void Process();
 
@@ -70,6 +70,7 @@ private:
     GM_ADDR scale2GM_;
     GM_ADDR probs_;
     GM_ADDR outGM_;
+    GM_ADDR gmExpertTokenNums_;
     GM_ADDR workspaceGM_;
 
     GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
@@ -114,7 +115,7 @@ private:
 template
 __aicore__ inline void DispatchFFNCombineBF16::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-                                                    GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+                                                    GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
     auto tiling = (__gm__ DispatchFFNCombineBF16TilingData*)tilingGM;
@@ -129,6 +130,7 @@ __aicore__ inline void DispatchFFNCombineBF16::Init(GM_ADDR
     probs_ = probs;
     outGM_ = outGM;
+    gmExpertTokenNums_ = expertTokenNums;
     workspaceGM_ = workspaceGM;
 
@@ -278,7 +280,7 @@ __aicore__ inline void DispatchFFNCombineBF16::Process()
         outGM_, layoutD1, layoutD2, expertIdGM_,
         moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset, expertTokensBeforeCapacity, probs_,
-        workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
+        workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
 
     //Call kernel
     MatmulKernel kernel(params);
     kernel(params);
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
index 5877a370..231e9f72 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
@@ -97,6 +97,7 @@ public:
     LayoutD1 layoutD1;
     LayoutD2 layoutD2;
     GM_ADDR ptrWorkspace;
+    GM_ADDR ptrExpertTokenNums;
     int32_t EP;
     int32_t listLen;
     int32_t expertPerRank;
@@ -141,7 +142,7 @@ public:
         GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_, GM_ADDR moeInitRoutingQuantV2Offset_,
         GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
-        GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
+        GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
         optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData_, GM_ADDR symmetricPtr_ = nullptr
     ) : problemShape(problemShape_),
@@ -158,7 +159,7 @@ public:
         expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
         moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
-        ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
+        ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
         moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
     {
         moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingData_.vbsComputeParamsOp;
@@ -518,14 +519,6 @@ CATLASS_DEVICE
     int64_t preCurrentmSum = 0;
     int32_t syncLoopIdx = -1;
 
-    __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
-    __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
-
-    int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-    for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-        weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr(loopIdx, params.ptrB1));
-        scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale1));
-    }
     AscendC::PipeBarrier();
 
     for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -538,8 +531,8 @@ CATLASS_DEVICE
         AscendC::GlobalTensor gmB1;
         AscendC::GlobalTensor gmS;
         int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-        gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
-        gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
+        gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1)));
+        gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1)));
 
         AscendC::PipeBarrier();
 
@@ -647,13 +640,6 @@ CATLASS_DEVICE
         lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
     }
 
-    __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
-    __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
-    int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-    for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-        weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(loopIdx, params.ptrB2));
-        scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale2));
-    }
     AscendC::PipeBarrier();
 
     for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -667,8 +653,8 @@ CATLASS_DEVICE
         AscendC::GlobalTensor gmS2;
         AscendC::PipeBarrier();
         int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-        gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
-        gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
+        gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB2)));
+        gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale2)));
 
         if (currentM <= L1TileShape::M) {
             gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
@@ -766,6 +752,13 @@ CATLASS_DEVICE
     AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
     syncgmm1Idx++;
+    AscendC::GlobalTensor ExpertTokenNums;
+    ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
+    AscendC::GlobalTensor LcalCumsumMM;
+    LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
+    CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
+    AscendC::SyncAll();
+
     uint32_t curGroupOffset = 0;
     int32_t prevSumBeforeRank = 0;
     int32_t groupIdxDeq = 0;
@@ -776,7 +769,8 @@ CATLASS_DEVICE
     uint32_t prevGroupSum1 = 0;
     uint32_t dequantSum = 0;
     int32_t syncLoopIdx = -1;
-    BlockEpilogue1 blockEpilogue(resource);
+    uint32_t n = params.problemShape.n();
+    BlockEpilogue1 blockEpilogue(resource, n);
     for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
         for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
             uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp
index 0862458f..ffd9b416 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp
@@ -85,16 +85,16 @@ public:
     };
 
     CATLASS_DEVICE
-    BlockEpilogue(Arch::Resource const &resource, Params const &params = Params{}) : params(params)
+    BlockEpilogue(Arch::Resource const &resource, int32_t n, Params const &params = Params{}) : params(params)
     {
         size_t ubOffset = 0;
         int32_t eventVMTE2 = 0;
         int32_t eventMTE2V = 0;
         int32_t eventMTE3V = 0;
         int32_t eventVMTE3 = 0;
-        constexpr uint32_t blockN = 4096;
-        constexpr uint32_t ChunkTileLen = blockN / 2;
-        constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
+        uint32_t blockN = n;
+        uint32_t ChunkTileLen = blockN / 2;
+        uint32_t HalfChunkTileLen = ChunkTileLen / 2;
 
         for (uint32_t i = 0; i < UB_STAGES; ++i) {
             ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset);
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index b08bb92f..146123ec 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -725,7 +725,7 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
     return;
 }
 
-at::Tensor& dispatch_ffn_combine(
+std::tuple dispatch_ffn_combine(
     const at::Tensor& x,
     const at::TensorList& weight1,
     const at::TensorList& weight2,
@@ -735,7 +735,8 @@ at::Tensor& dispatch_ffn_combine(
     const at::Tensor& probs,
     c10::string_view group,
    int64_t max_output_size,
-    at::Tensor& out
+    at::Tensor& out,
+    at::Tensor& expert_token_nums
 ) {
     char *group_ep_ptr = const_cast(group.data());
     bool is_int8 = weight1[0].dtype() == at::kChar;
@@ -750,7 +751,8 @@
             probs,
             group_ep_ptr,
             max_output_size,
-            out);
+            out,
+            expert_token_nums);
     } else {
         EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
             x,
@@ -762,9 +764,10 @@
             probs,
             group_ep_ptr,
             max_output_size,
-            out);
+            out,
+            expert_token_nums);
     }
-    return out;
+    return {out, expert_token_nums};
 }
 
 at::Tensor npu_lightning_indexer(
@@ -1452,7 +1455,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
     ops.def(
         "dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx,"
         " Tensor[] scale1, Tensor[] scale2, Tensor probs, str group,"
-        " int max_output_size, Tensor! out) -> Tensor"
+        " int max_output_size, Tensor! out, Tensor! expert_token_nums) -> (Tensor out, Tensor expert_token_nums)"
     );
     ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index b19fc643..af550134 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -194,7 +194,7 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
     return;
 }
 
-at::Tensor& dispatch_ffn_combine_meta(
+std::tuple dispatch_ffn_combine_meta(
     const at::Tensor& x,
     const at::TensorList& weight1,
     const at::TensorList& weight2,
@@ -204,9 +204,10 @@
     const at::Tensor& probs,
     c10::string_view group,
     int64_t max_output_size,
-    at::Tensor& out
+    at::Tensor& out,
+    at::Tensor& expert_token_nums
 ) {
-    return out;
+    return {out, expert_token_nums};
 }
 
 at::Tensor npu_lightning_indexer_meta(
diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py
index 9f86cc0f..5897ef0c 100644
--- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py
+++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py
@@ -126,6 +126,7 @@ class TestDisptachFFNCombine:
             scale2_npu.append(scale2[i].npu())
 
         out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+        expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
 
         torch.ops._C_ascend.dispatch_ffn_combine(
             x=x,
@@ -138,6 +139,7 @@
             group=self.hcomm_info,
            max_output_size=512,
             out=out,
+            expert_token_nums=expert_token_nums,
         )
 
         return True
@@ -177,6 +179,7 @@ class TestDisptachFFNCombine:
            scale2_npu.append(scale2.npu())
 
         out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+        expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
 
         torch.ops._C_ascend.dispatch_ffn_combine(
             x=x,
@@ -189,6 +192,7 @@ class TestDisptachFFNCombine:
             group=self.hcomm_info,
             max_output_size=512,
             out=out,
+            expert_token_nums=expert_token_nums,
         )
 
         return True
diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py
index 5b50c07a..91289940 100644
--- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py
+++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine_bf16.py
@@ -126,6 +126,7 @@ class TestDisptachFFNCombine:
             scale2_npu.append(scale2[i].npu())
 
         out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+        expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
 
         torch.ops._C_ascend.dispatch_ffn_combine(
             x=x,
@@ -138,6 +139,7 @@ class TestDisptachFFNCombine:
             group=self.hcomm_info,
             max_output_size=512,
             out=out,
+            expert_token_nums=expert_token_nums,
         )
 
         return True
@@ -177,6 +179,7 @@ class TestDisptachFFNCombine:
             scale2_npu.append(scale2.npu())
 
         out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+        expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
 
         torch.ops._C_ascend.dispatch_ffn_combine(
             x=x,
@@ -189,6 +192,7 @@ class TestDisptachFFNCombine:
             group=self.hcomm_info,
             max_output_size=512,
             out=out,
+            expert_token_nums=expert_token_nums,
         )
 
         return True
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 458557e9..d135968c 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -315,6 +315,7 @@ class FusedMC2CommImpl(MoECommMethod):
         expert_tokens = None
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             out = torch.empty_like(hidden_states)
+            expert_token_nums = torch.zeros([self.moe_config.num_local_experts], dtype=torch.int32)
             torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
                 x=hidden_states,
                 weight1=w1,
@@ -326,7 +327,9 @@
                 group=self.token_dispatcher.moe_all_to_all_group_name,
                 max_output_size=65536,
                 out=out,
+                expert_token_nums=expert_token_nums,
             )
+            expert_tokens = expert_token_nums
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             assert expert_map is not None, "expert_map cannot be None."
             group_list_type = 1
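
Usage note (not part of the patch): the schema registered in torch_binding.cpp now takes a pre-allocated expert_token_nums buffer and returns it alongside out. The sketch below illustrates one way a caller might allocate both outputs before invoking the op, following the shapes and dtypes used in the e2e tests and moe_comm_method.py above. The helper name, the device placement of expert_token_nums, and the max_output_size value are illustrative assumptions; it also presumes a vllm-ascend build where the _C_ascend ops are registered and an HCCL group name is at hand.

# Minimal usage sketch (illustration only; call_dispatch_ffn_combine,
# hcomm_group and num_local_experts are placeholders, not from the patch).
import torch


def call_dispatch_ffn_combine(x, weight1, weight2, expert_idx, scale1, scale2,
                              probs, hcomm_group, num_local_experts):
    # Combined output: same shape/dtype as the input hidden states, mirroring
    # torch.empty_like(hidden_states) in moe_comm_method.py.
    out = torch.empty_like(x)
    # New output wired through by this patch: one int32 token count per local
    # expert, filled by the kernel from the cumsum workspace.
    expert_token_nums = torch.zeros((num_local_experts,), dtype=torch.int32,
                                    device=x.device)
    torch.ops._C_ascend.dispatch_ffn_combine(
        x=x,
        weight1=weight1,
        weight2=weight2,
        expert_idx=expert_idx,
        scale1=scale1,
        scale2=scale2,
        probs=probs,
        group=hcomm_group,
        max_output_size=512,  # the e2e tests use 512; moe_comm_method.py uses 65536
        out=out,
        expert_token_nums=expert_token_nums,
    )
    return out, expert_token_nums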