[Refactor] Add expert processed token count output for DispatchFFNCombine/DispatchFFNCombineBF16 (#6402)
### What this PR does / why we need it?
Add a new output for the per-expert token count.
An additional output tensor, `expert_token_nums`, is added to both operators
(DispatchFFNCombine and DispatchFFNCombineBF16) so that the distribution of tokens among experts can be tracked:
- Tensor name: `expert_token_nums`
- Dimension: 1-D tensor
- Shape: `(local_expert_num,)`
- Data type: int32
- Semantics: the number of tokens actually received by each expert on the current card.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
Signed-off-by: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
This commit is contained in:
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor*
|
|||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
bool transB, bool weightNz,
|
bool transB, bool weightNz,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||||
extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize,
|
extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize,
|
||||||
aclOpExecutor *executor, aclrtStream stream);
|
aclOpExecutor *executor, aclrtStream stream);
|
||||||
@@ -59,7 +59,7 @@ aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const ac
|
|||||||
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor)
|
uint64_t* workspaceSize, aclOpExecutor** executor)
|
||||||
{
|
{
|
||||||
bool transB = false;
|
bool transB = false;
|
||||||
@@ -67,7 +67,7 @@ aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const ac
|
|||||||
|
|
||||||
aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
|
aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
|
||||||
maxOutputSize, transB, weightNz,
|
maxOutputSize, transB, weightNz,
|
||||||
out, workspaceSize, executor);
|
out, expertTokenNums, workspaceSize, executor);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWor
|
|||||||
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ class DispatchFFNCombine : public OpDef {
|
|||||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
|
||||||
|
this->Output("expert_token_nums")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||||
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
|
|
||||||
this->Attr("group").AttrType(REQUIRED).String();
|
this->Attr("group").AttrType(REQUIRED).String();
|
||||||
this->Attr("M").AttrType(OPTIONAL).Int();
|
this->Attr("M").AttrType(OPTIONAL).Int();
|
||||||
|
|||||||
@@ -20,32 +20,32 @@
|
|||||||
using namespace AscendC;
|
using namespace AscendC;
|
||||||
using namespace DispatchFFNCombineImpl;
|
using namespace DispatchFFNCombineImpl;
|
||||||
extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
|
extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
|
||||||
GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||||
{
|
{
|
||||||
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
||||||
if (TILING_KEY_IS(1000000)) {
|
if (TILING_KEY_IS(1000000)) {
|
||||||
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
|
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000001)) {
|
} else if (TILING_KEY_IS(1000001)) {
|
||||||
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
|
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000010)) {
|
} else if (TILING_KEY_IS(1000010)) {
|
||||||
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
|
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000011)) {
|
} else if (TILING_KEY_IS(1000011)) {
|
||||||
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
|
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -55,7 +55,7 @@ class DispatchFFNCombine {
|
|||||||
public:
|
public:
|
||||||
__aicore__ inline DispatchFFNCombine() {};
|
__aicore__ inline DispatchFFNCombine() {};
|
||||||
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
||||||
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
|
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
|
||||||
__aicore__ inline void Process();
|
__aicore__ inline void Process();
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +68,7 @@ private:
|
|||||||
GM_ADDR scale2GM_;
|
GM_ADDR scale2GM_;
|
||||||
GM_ADDR probs_;
|
GM_ADDR probs_;
|
||||||
GM_ADDR outGM_;
|
GM_ADDR outGM_;
|
||||||
|
GM_ADDR gmExpertTokenNums_;
|
||||||
GM_ADDR workspaceGM_;
|
GM_ADDR workspaceGM_;
|
||||||
|
|
||||||
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
|
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
|
||||||
@@ -112,7 +113,7 @@ private:
|
|||||||
|
|
||||||
template <TemplateMMA2AClass>
|
template <TemplateMMA2AClass>
|
||||||
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
__aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
||||||
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||||
{
|
{
|
||||||
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
|
||||||
auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
|
auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
|
||||||
@@ -127,6 +128,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM,
|
|||||||
probs_ = probs;
|
probs_ = probs;
|
||||||
|
|
||||||
outGM_ = outGM;
|
outGM_ = outGM;
|
||||||
|
gmExpertTokenNums_ = expertTokenNums;
|
||||||
|
|
||||||
workspaceGM_ = workspaceGM;
|
workspaceGM_ = workspaceGM;
|
||||||
|
|
||||||
@@ -268,7 +270,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
|
|||||||
outGM_, layoutD1, layoutD2,
|
outGM_, layoutD1, layoutD2,
|
||||||
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
|
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
|
||||||
expertTokensBeforeCapacity, probs_,
|
expertTokensBeforeCapacity, probs_,
|
||||||
workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
|
workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
|
||||||
//Call kernel
|
//Call kernel
|
||||||
MatmulKernel kernel(params);
|
MatmulKernel kernel(params);
|
||||||
kernel(params);
|
kernel(params);
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ public:
|
|||||||
LayoutD1 layoutD1;
|
LayoutD1 layoutD1;
|
||||||
LayoutD2 layoutD2;
|
LayoutD2 layoutD2;
|
||||||
GM_ADDR ptrWorkspace;
|
GM_ADDR ptrWorkspace;
|
||||||
|
GM_ADDR ptrExpertTokenNums;
|
||||||
int32_t EP;
|
int32_t EP;
|
||||||
int32_t listLen;
|
int32_t listLen;
|
||||||
int32_t expertPerRank;
|
int32_t expertPerRank;
|
||||||
@@ -139,7 +140,7 @@ public:
|
|||||||
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
|
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
|
||||||
GM_ADDR moeInitRoutingQuantV2Offset_,
|
GM_ADDR moeInitRoutingQuantV2Offset_,
|
||||||
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
|
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
|
||||||
GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
|
GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
|
||||||
optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
|
optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
|
||||||
) : problemShape(problemShape_),
|
) : problemShape(problemShape_),
|
||||||
EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
|
EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
|
||||||
@@ -155,7 +156,7 @@ public:
|
|||||||
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
|
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
|
||||||
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
|
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
|
||||||
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
|
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
|
||||||
ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),
|
ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),
|
||||||
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
|
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -228,7 +229,7 @@ private:
|
|||||||
|
|
||||||
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
|
tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
|
||||||
|
|
||||||
tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@@ -335,15 +336,6 @@ private:
|
|||||||
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
|
||||||
syncgmmIdx++;
|
syncgmmIdx++;
|
||||||
|
|
||||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
|
|
||||||
__gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
__gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
|
|
||||||
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
|
|
||||||
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
|
|
||||||
weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
|
|
||||||
scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
|
|
||||||
}
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
@@ -356,8 +348,8 @@ private:
|
|||||||
AscendC::GlobalTensor<ElementB> gmB1;
|
AscendC::GlobalTensor<ElementB> gmB1;
|
||||||
AscendC::GlobalTensor<ElementScale> gmS;
|
AscendC::GlobalTensor<ElementScale> gmS;
|
||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
|
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
|
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
@@ -455,14 +447,6 @@ private:
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr uint32_t MAX_EXPERTS_PER_RANK = 8;
|
|
||||||
__gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
__gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
|
|
||||||
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
|
|
||||||
weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
|
|
||||||
scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
|
|
||||||
}
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
@@ -476,8 +460,8 @@ private:
|
|||||||
AscendC::GlobalTensor<ElementScale> gmS2;
|
AscendC::GlobalTensor<ElementScale> gmS2;
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
|
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
|
||||||
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
|
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
|
||||||
|
|
||||||
if (currentM <= L1TileShape::M) {
|
if (currentM <= L1TileShape::M) {
|
||||||
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||||
@@ -596,7 +580,6 @@ private:
|
|||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
if (dstEpIdx == params.rank) {
|
if (dstEpIdx == params.rank) {
|
||||||
continue;
|
continue;
|
||||||
@@ -639,6 +622,13 @@ private:
|
|||||||
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
||||||
|
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
||||||
|
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
|
||||||
|
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
|
||||||
|
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
uint16_t syncgmm1Idx = 0;
|
uint16_t syncgmm1Idx = 0;
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
syncgmm1Idx++;
|
syncgmm1Idx++;
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(const aclTen
|
|||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
bool transB, bool weightNz,
|
bool transB, bool weightNz,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||||
extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize,
|
extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize,
|
||||||
aclOpExecutor *executor, aclrtStream stream);
|
aclOpExecutor *executor, aclrtStream stream);
|
||||||
@@ -59,7 +59,7 @@ aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, cons
|
|||||||
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor)
|
uint64_t* workspaceSize, aclOpExecutor** executor)
|
||||||
{
|
{
|
||||||
bool transB = false;
|
bool transB = false;
|
||||||
@@ -67,7 +67,7 @@ aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, cons
|
|||||||
|
|
||||||
aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
|
aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
|
||||||
maxOutputSize, transB, weightNz,
|
maxOutputSize, transB, weightNz,
|
||||||
out, workspaceSize, executor);
|
out, expertTokenNums, workspaceSize, executor);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16Ge
|
|||||||
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
|
||||||
const aclTensor* probs,
|
const aclTensor* probs,
|
||||||
const char* group, int64_t maxOutputSize,
|
const char* group, int64_t maxOutputSize,
|
||||||
const aclTensor* out,
|
const aclTensor* out, const aclTensor* expertTokenNums,
|
||||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ class DispatchFFNCombineBF16 : public OpDef {
|
|||||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
|
||||||
.Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
|
this->Output("expert_token_nums")
|
||||||
|
.ParamType(REQUIRED)
|
||||||
|
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||||
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
|
|
||||||
this->Attr("group").AttrType(REQUIRED).String();
|
this->Attr("group").AttrType(REQUIRED).String();
|
||||||
this->Attr("M").AttrType(OPTIONAL).Int();
|
this->Attr("M").AttrType(OPTIONAL).Int();
|
||||||
|
|||||||
@@ -20,32 +20,32 @@
|
|||||||
using namespace AscendC;
|
using namespace AscendC;
|
||||||
using namespace DispatchFFNCombineBF16Impl;
|
using namespace DispatchFFNCombineBF16Impl;
|
||||||
extern "C" __global__ __aicore__ void dispatch_ffn_combine_bf16(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
|
extern "C" __global__ __aicore__ void dispatch_ffn_combine_bf16(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
|
||||||
GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||||
{
|
{
|
||||||
REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
|
REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
|
||||||
if (TILING_KEY_IS(1000000)) {
|
if (TILING_KEY_IS(1000000)) {
|
||||||
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
|
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000001)) {
|
} else if (TILING_KEY_IS(1000001)) {
|
||||||
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, false> op;
|
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, false> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000010)) {
|
} else if (TILING_KEY_IS(1000010)) {
|
||||||
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
|
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, false, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
} else if (TILING_KEY_IS(1000011)) {
|
} else if (TILING_KEY_IS(1000011)) {
|
||||||
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
|
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
|
||||||
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineBF16TilingData, tilingData, tilingGM);
|
||||||
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, true> op;
|
DispatchFFNCombineBF16<DTYPE_A, DTYPE_W1, DTYPE_OUT, true, true> op;
|
||||||
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
|
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
|
||||||
op.Process();
|
op.Process();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -57,7 +57,7 @@ class DispatchFFNCombineBF16 {
|
|||||||
public:
|
public:
|
||||||
__aicore__ inline DispatchFFNCombineBF16() {};
|
__aicore__ inline DispatchFFNCombineBF16() {};
|
||||||
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
__aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
||||||
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
|
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
|
||||||
__aicore__ inline void Process();
|
__aicore__ inline void Process();
|
||||||
|
|
||||||
|
|
||||||
@@ -70,6 +70,7 @@ private:
|
|||||||
GM_ADDR scale2GM_;
|
GM_ADDR scale2GM_;
|
||||||
GM_ADDR probs_;
|
GM_ADDR probs_;
|
||||||
GM_ADDR outGM_;
|
GM_ADDR outGM_;
|
||||||
|
GM_ADDR gmExpertTokenNums_;
|
||||||
GM_ADDR workspaceGM_;
|
GM_ADDR workspaceGM_;
|
||||||
|
|
||||||
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
|
GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
|
||||||
@@ -114,7 +115,7 @@ private:
|
|||||||
|
|
||||||
template <TemplateMMA2AClass>
|
template <TemplateMMA2AClass>
|
||||||
__aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
__aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
|
||||||
GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
|
||||||
{
|
{
|
||||||
REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
|
REGISTER_TILING_DEFAULT(DispatchFFNCombineBF16TilingData);
|
||||||
auto tiling = (__gm__ DispatchFFNCombineBF16TilingData*)tilingGM;
|
auto tiling = (__gm__ DispatchFFNCombineBF16TilingData*)tilingGM;
|
||||||
@@ -129,6 +130,7 @@ __aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Init(GM_ADDR
|
|||||||
probs_ = probs;
|
probs_ = probs;
|
||||||
|
|
||||||
outGM_ = outGM;
|
outGM_ = outGM;
|
||||||
|
gmExpertTokenNums_ = expertTokenNums;
|
||||||
|
|
||||||
workspaceGM_ = workspaceGM;
|
workspaceGM_ = workspaceGM;
|
||||||
|
|
||||||
@@ -278,7 +280,7 @@ __aicore__ inline void DispatchFFNCombineBF16<TemplateMMA2ACFunc>::Process()
|
|||||||
outGM_, layoutD1, layoutD2,
|
outGM_, layoutD1, layoutD2,
|
||||||
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
|
expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
|
||||||
expertTokensBeforeCapacity, probs_,
|
expertTokensBeforeCapacity, probs_,
|
||||||
workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
|
workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
|
||||||
//Call kernel
|
//Call kernel
|
||||||
MatmulKernel kernel(params);
|
MatmulKernel kernel(params);
|
||||||
kernel(params);
|
kernel(params);
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ public:
|
|||||||
LayoutD1 layoutD1;
|
LayoutD1 layoutD1;
|
||||||
LayoutD2 layoutD2;
|
LayoutD2 layoutD2;
|
||||||
GM_ADDR ptrWorkspace;
|
GM_ADDR ptrWorkspace;
|
||||||
|
GM_ADDR ptrExpertTokenNums;
|
||||||
int32_t EP;
|
int32_t EP;
|
||||||
int32_t listLen;
|
int32_t listLen;
|
||||||
int32_t expertPerRank;
|
int32_t expertPerRank;
|
||||||
@@ -141,7 +142,7 @@ public:
|
|||||||
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
|
GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
|
||||||
GM_ADDR moeInitRoutingQuantV2Offset_,
|
GM_ADDR moeInitRoutingQuantV2Offset_,
|
||||||
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
|
GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
|
||||||
GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
|
GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
|
||||||
optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData_,
|
optiling::MoeInitRoutingV2TilingData moeInitRoutingQuantV2TilingData_,
|
||||||
GM_ADDR symmetricPtr_ = nullptr
|
GM_ADDR symmetricPtr_ = nullptr
|
||||||
) : problemShape(problemShape_),
|
) : problemShape(problemShape_),
|
||||||
@@ -158,7 +159,7 @@ public:
|
|||||||
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
|
expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
|
||||||
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
|
moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
|
||||||
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
|
expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
|
||||||
ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
|
ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_),
|
||||||
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
|
moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
|
||||||
{
|
{
|
||||||
moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingData_.vbsComputeParamsOp;
|
moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingData_.vbsComputeParamsOp;
|
||||||
@@ -518,14 +519,6 @@ CATLASS_DEVICE
|
|||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
|
|
||||||
__gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
__gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
|
|
||||||
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
|
|
||||||
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
|
|
||||||
weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
|
|
||||||
scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
|
|
||||||
}
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
@@ -538,8 +531,8 @@ CATLASS_DEVICE
|
|||||||
AscendC::GlobalTensor<ElementB> gmB1;
|
AscendC::GlobalTensor<ElementB> gmB1;
|
||||||
AscendC::GlobalTensor<ElementScale> gmS;
|
AscendC::GlobalTensor<ElementScale> gmS;
|
||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
|
gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
|
||||||
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
|
gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));
|
||||||
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
@@ -647,13 +640,6 @@ CATLASS_DEVICE
|
|||||||
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
|
||||||
}
|
}
|
||||||
|
|
||||||
__gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
__gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
|
|
||||||
int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
|
|
||||||
for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
|
|
||||||
weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
|
|
||||||
scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
|
|
||||||
}
|
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
|
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
@@ -667,8 +653,8 @@ CATLASS_DEVICE
|
|||||||
AscendC::GlobalTensor<ElementScale> gmS2;
|
AscendC::GlobalTensor<ElementScale> gmS2;
|
||||||
AscendC::PipeBarrier<PIPE_ALL>();
|
AscendC::PipeBarrier<PIPE_ALL>();
|
||||||
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
|
||||||
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
|
gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
|
||||||
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
|
gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));
|
||||||
|
|
||||||
if (currentM <= L1TileShape::M) {
|
if (currentM <= L1TileShape::M) {
|
||||||
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
|
||||||
@@ -766,6 +752,13 @@ CATLASS_DEVICE
|
|||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
|
||||||
syncgmm1Idx++;
|
syncgmm1Idx++;
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<int32_t> ExpertTokenNums;
|
||||||
|
ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
|
||||||
|
AscendC::GlobalTensor<int32_t> LcalCumsumMM;
|
||||||
|
LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
|
||||||
|
CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
|
||||||
|
AscendC::SyncAll<true>();
|
||||||
|
|
||||||
uint32_t curGroupOffset = 0;
|
uint32_t curGroupOffset = 0;
|
||||||
int32_t prevSumBeforeRank = 0;
|
int32_t prevSumBeforeRank = 0;
|
||||||
int32_t groupIdxDeq = 0;
|
int32_t groupIdxDeq = 0;
|
||||||
@@ -776,7 +769,8 @@ CATLASS_DEVICE
|
|||||||
uint32_t prevGroupSum1 = 0;
|
uint32_t prevGroupSum1 = 0;
|
||||||
uint32_t dequantSum = 0;
|
uint32_t dequantSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
BlockEpilogue1 blockEpilogue(resource);
|
uint32_t n = params.problemShape.n();
|
||||||
|
BlockEpilogue1 blockEpilogue(resource, n);
|
||||||
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
||||||
|
|||||||
@@ -85,16 +85,16 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
|
BlockEpilogue(Arch::Resource<ArchTag> const &resource, int32_t n, Params const &params = Params{}) : params(params)
|
||||||
{
|
{
|
||||||
size_t ubOffset = 0;
|
size_t ubOffset = 0;
|
||||||
int32_t eventVMTE2 = 0;
|
int32_t eventVMTE2 = 0;
|
||||||
int32_t eventMTE2V = 0;
|
int32_t eventMTE2V = 0;
|
||||||
int32_t eventMTE3V = 0;
|
int32_t eventMTE3V = 0;
|
||||||
int32_t eventVMTE3 = 0;
|
int32_t eventVMTE3 = 0;
|
||||||
constexpr uint32_t blockN = 4096;
|
uint32_t blockN = n;
|
||||||
constexpr uint32_t ChunkTileLen = blockN / 2;
|
uint32_t ChunkTileLen = blockN / 2;
|
||||||
constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
uint32_t HalfChunkTileLen = ChunkTileLen / 2;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||||
|
|||||||
@@ -725,7 +725,7 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor& dispatch_ffn_combine(
|
std::tuple<at::Tensor&, at::Tensor&> dispatch_ffn_combine(
|
||||||
const at::Tensor& x,
|
const at::Tensor& x,
|
||||||
const at::TensorList& weight1,
|
const at::TensorList& weight1,
|
||||||
const at::TensorList& weight2,
|
const at::TensorList& weight2,
|
||||||
@@ -735,7 +735,8 @@ at::Tensor& dispatch_ffn_combine(
|
|||||||
const at::Tensor& probs,
|
const at::Tensor& probs,
|
||||||
c10::string_view group,
|
c10::string_view group,
|
||||||
int64_t max_output_size,
|
int64_t max_output_size,
|
||||||
at::Tensor& out
|
at::Tensor& out,
|
||||||
|
at::Tensor& expert_token_nums
|
||||||
) {
|
) {
|
||||||
char *group_ep_ptr = const_cast<char *>(group.data());
|
char *group_ep_ptr = const_cast<char *>(group.data());
|
||||||
bool is_int8 = weight1[0].dtype() == at::kChar;
|
bool is_int8 = weight1[0].dtype() == at::kChar;
|
||||||
@@ -750,7 +751,8 @@ at::Tensor& dispatch_ffn_combine(
|
|||||||
probs,
|
probs,
|
||||||
group_ep_ptr,
|
group_ep_ptr,
|
||||||
max_output_size,
|
max_output_size,
|
||||||
out);
|
out,
|
||||||
|
expert_token_nums);
|
||||||
} else {
|
} else {
|
||||||
EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
|
EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
|
||||||
x,
|
x,
|
||||||
@@ -762,9 +764,10 @@ at::Tensor& dispatch_ffn_combine(
|
|||||||
probs,
|
probs,
|
||||||
group_ep_ptr,
|
group_ep_ptr,
|
||||||
max_output_size,
|
max_output_size,
|
||||||
out);
|
out,
|
||||||
|
expert_token_nums);
|
||||||
}
|
}
|
||||||
return out;
|
return {out, expert_token_nums};
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor npu_lightning_indexer(
|
at::Tensor npu_lightning_indexer(
|
||||||
@@ -1452,7 +1455,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
ops.def(
|
ops.def(
|
||||||
"dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx,"
|
"dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx,"
|
||||||
" Tensor[] scale1, Tensor[] scale2, Tensor probs, str group,"
|
" Tensor[] scale1, Tensor[] scale2, Tensor probs, str group,"
|
||||||
" int max_output_size, Tensor! out) -> Tensor"
|
" int max_output_size, Tensor! out, Tensor! expert_token_nums) -> (Tensor out, Tensor expert_token_nums)"
|
||||||
);
|
);
|
||||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||||
|
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor& dispatch_ffn_combine_meta(
|
std::tuple<at::Tensor&, at::Tensor&> dispatch_ffn_combine_meta(
|
||||||
const at::Tensor& x,
|
const at::Tensor& x,
|
||||||
const at::TensorList& weight1,
|
const at::TensorList& weight1,
|
||||||
const at::TensorList& weight2,
|
const at::TensorList& weight2,
|
||||||
@@ -204,9 +204,10 @@ at::Tensor& dispatch_ffn_combine_meta(
|
|||||||
const at::Tensor& probs,
|
const at::Tensor& probs,
|
||||||
c10::string_view group,
|
c10::string_view group,
|
||||||
int64_t max_output_size,
|
int64_t max_output_size,
|
||||||
at::Tensor& out
|
at::Tensor& out,
|
||||||
|
at::Tensor& expert_token_nums
|
||||||
) {
|
) {
|
||||||
return out;
|
return {out, expert_token_nums};
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor npu_lightning_indexer_meta(
|
at::Tensor npu_lightning_indexer_meta(
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class TestDisptachFFNCombine:
|
|||||||
scale2_npu.append(scale2[i].npu())
|
scale2_npu.append(scale2[i].npu())
|
||||||
|
|
||||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||||
|
expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
|
||||||
|
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -138,6 +139,7 @@ class TestDisptachFFNCombine:
|
|||||||
group=self.hcomm_info,
|
group=self.hcomm_info,
|
||||||
max_output_size=512,
|
max_output_size=512,
|
||||||
out=out,
|
out=out,
|
||||||
|
expert_token_nums=expert_token_nums,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -177,6 +179,7 @@ class TestDisptachFFNCombine:
|
|||||||
scale2_npu.append(scale2.npu())
|
scale2_npu.append(scale2.npu())
|
||||||
|
|
||||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||||
|
expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
|
||||||
|
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -189,6 +192,7 @@ class TestDisptachFFNCombine:
|
|||||||
group=self.hcomm_info,
|
group=self.hcomm_info,
|
||||||
max_output_size=512,
|
max_output_size=512,
|
||||||
out=out,
|
out=out,
|
||||||
|
expert_token_nums=expert_token_nums,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class TestDisptachFFNCombine:
|
|||||||
scale2_npu.append(scale2[i].npu())
|
scale2_npu.append(scale2[i].npu())
|
||||||
|
|
||||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||||
|
expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
|
||||||
|
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -138,6 +139,7 @@ class TestDisptachFFNCombine:
|
|||||||
group=self.hcomm_info,
|
group=self.hcomm_info,
|
||||||
max_output_size=512,
|
max_output_size=512,
|
||||||
out=out,
|
out=out,
|
||||||
|
expert_token_nums=expert_token_nums,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -177,6 +179,7 @@ class TestDisptachFFNCombine:
|
|||||||
scale2_npu.append(scale2.npu())
|
scale2_npu.append(scale2.npu())
|
||||||
|
|
||||||
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
|
||||||
|
expert_token_nums = self.generate_random_tensor((1, e), dtype=torch.int32).npu()
|
||||||
|
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -189,6 +192,7 @@ class TestDisptachFFNCombine:
|
|||||||
group=self.hcomm_info,
|
group=self.hcomm_info,
|
||||||
max_output_size=512,
|
max_output_size=512,
|
||||||
out=out,
|
out=out,
|
||||||
|
expert_token_nums=expert_token_nums,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -315,6 +315,7 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
expert_tokens = None
|
expert_tokens = None
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(hidden_states)
|
||||||
|
expert_token_nums = torch.zeros([self.moe_config.num_local_experts], dtype=torch.int32)
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight1=w1,
|
weight1=w1,
|
||||||
@@ -326,7 +327,9 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
group=self.token_dispatcher.moe_all_to_all_group_name,
|
group=self.token_dispatcher.moe_all_to_all_group_name,
|
||||||
max_output_size=65536,
|
max_output_size=65536,
|
||||||
out=out,
|
out=out,
|
||||||
|
expert_token_nums=expert_token_nums,
|
||||||
)
|
)
|
||||||
|
expert_tokens = expert_token_nums
|
||||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
assert expert_map is not None, "expert_map cannot be None."
|
assert expert_map is not None, "expert_map cannot be None."
|
||||||
group_list_type = 1
|
group_list_type = 1
|
||||||
|
|||||||
Reference in New Issue
Block a user