[CustomOp] support TensorList for dispatchFFNCombine (#5665)
### What this PR does / why we need it?
Support `TensorList` (`Tensor[]`) inputs for `dispatch_ffn_combine`, so per-expert weights and scales can be passed as lists of tensors instead of single stacked tensors; this is needed to adapt the op for EPLB (expert parallelism load balancing).
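For orientation, here is a minimal sketch of the call after this change; it mirrors the updated single-operator test further down in the diff. The shapes, the `world_size` value, and the `hcomm_info` group name are illustrative assumptions, and a real run needs an initialized HCCL process group on Ascend NPUs:

```python
import torch
import torch_npu

# Illustrative shapes, mirroring the updated test: tokens, hidden size,
# mid-hidden size, routing top-k, experts per rank.
m, k, n, topk, e = 64, 1024, 1024, 8, 8
world_size = 16  # assumed size of the expert-parallel group

x = torch.randn(m, k, dtype=torch.bfloat16).npu()
expert_idx = torch.randint(0, world_size * e, (m, topk), dtype=torch.int32).npu()
probs = torch.randn(m, topk, dtype=torch.float32).npu()
out = torch.empty(m, k, dtype=torch.bfloat16).npu()

# The user-visible change: per-expert weights and scales are passed as Python
# lists (bound as Tensor[] / at::TensorList), one entry per local expert, with
# each weight cast to FRACTAL_NZ (ACL format id 29).
weight1 = [torch_npu.npu_format_cast(torch.randint(-8, 8, (k, n), dtype=torch.int8).npu(), 29)
           for _ in range(e)]
weight2 = [torch_npu.npu_format_cast(torch.randint(-8, 8, (n // 2, k), dtype=torch.int8).npu(), 29)
           for _ in range(e)]
scale1 = [torch.zeros(n, dtype=torch.int64).npu() for _ in range(e)]
scale2 = [torch.zeros(k, dtype=torch.int64).npu() for _ in range(e)]

torch.ops._C_ascend.dispatch_ffn_combine(
    x=x,
    weight1=weight1,
    weight2=weight2,
    expert_idx=expert_idx,
    scale1=scale1,
    scale2=scale2,
    probs=probs,
    group=hcomm_info,  # HCCL group name, assumed already created
    max_output_size=512,
    out=out,
)
```

Passing a single-element list holding the old stacked `(e, k, n)` weight keeps the previous contiguous-weight behaviour (the `listLen == 1` path in the tiling code below).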
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Single-operator testing; see the updated `TestDisptachFFNCombine` cases (`run_tensor_list` and `run_normal`) in the diff below.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
@@ -42,8 +42,8 @@ enum NnopbaseHcclServerType {
     NNOPBASE_HCCL_SERVER_TYPE_END
 };
 
-extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-                                                                const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+                                                                const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                                 const aclTensor* probs,
                                                                 const char* group, int64_t maxOutputSize,
                                                                 bool transB, bool weightNz,
@@ -55,8 +55,8 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,
 
 
 
-aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-                                                    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+                                                    const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                     const aclTensor* probs,
                                                     const char* group, int64_t maxOutputSize,
                                                     const aclTensor* out,
@@ -39,8 +39,8 @@ extern "C" {
  * @param [out] executor: op executor containing the operator compute flow.
  * @return aclnnStatus: status code.
  */
-__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-                                                    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+                                                    const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                     const aclTensor* probs,
                                                     const char* group, int64_t maxOutputSize,
                                                     const aclTensor* out,
@@ -24,13 +24,13 @@ class DispatchFFNCombine : public OpDef {
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("w1")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
             .IgnoreContiguous();
         this->Input("w2")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
@@ -41,12 +41,12 @@ class DispatchFFNCombine : public OpDef {
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("scale1")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("scale2")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
@@ -91,27 +91,42 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
 static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
 {
     const char *nodeName = context->GetNodeName();
-    // OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling.");
 
     const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
-    const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX);
-    const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX);
+    auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0);
     uint32_t M = aStorageShape->GetStorageShape().GetDim(0);
     uint32_t K = aStorageShape->GetStorageShape().GetDim(1);
-    uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0);
-    uint32_t N = bStorageShape->GetStorageShape().GetDim(2);
-    uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1);
+    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
+    uint32_t wTensorDims = wTensor->GetOriginShape().GetDimNum();
+    uint32_t N = wTensor->GetStorageShape().GetDim(wTensorDims - 1);
+
+    uint32_t topK = expertIdxTensor->GetStorageShape().GetDim(1);
+    uint32_t listLen = 0;
+    while (true) {
+        auto wTensorT = context->GetDynamicInputTensor(WEIGHT_INDEX, ++listLen);
+        if (wTensorT == nullptr) {break;}
+    }
+
+    uint32_t expertPerRank;
+    if (listLen == 1) {
+        expertPerRank = wTensor->GetStorageShape().GetDim(0);
+    } else {
+        expertPerRank = listLen;
+    }
+
     info.M = M;
     info.N = N;
     info.K = K;
     info.expertPerRank = expertPerRank;
     info.topK = topK;
+    info.listLen = listLen;
     OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
     OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
     OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
     OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
     OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
+    OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen);
 
     return ge::GRAPH_SUCCESS;
 }
@@ -100,6 +100,7 @@ private:
     int32_t expertPerRank;
    int32_t maxOutputSize;
     int32_t EP;
+    int32_t listLen;
 
     optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
     uint64_t initRoutingQuantTilingKey;
@@ -138,6 +139,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM,
     topK = tilingData.dispatchFFNCombineInfo.topK;
     expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank;
     maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize;
+    listLen = tilingData.dispatchFFNCombineInfo.listLen;
 
     m0 = tilingData.cocTiling.m0;
     k0 = tilingData.cocTiling.k0;
@@ -254,7 +256,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
     uint32_t epilogueGranularity = expertPerRank - 1;
 
     typename MatmulKernel::Params params{
-        problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
+        problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(listLen), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
         static_cast<uint32_t>(rank), static_cast<uint32_t>(rankSize),
         static_cast<uint32_t>(topK), initRoutingQuantTilingKey,
         epilogueCoreNum, epilogueGranularity,
@@ -30,6 +30,7 @@
 #include "utils/hccl_shmem.hpp"
 #include "utils/const_args.hpp"
 #include "utils/layout3d.hpp"
+#include "utils/get_tensor_addr.hpp"
 
 #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
 #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp"
@@ -79,19 +80,20 @@ public:
         __gm__ ElementA *ptrA;
         LayoutA layoutA;
         LayoutA layoutA2;
-        __gm__ ElementB *ptrB1;
+        GM_ADDR ptrB1;
         LayoutB layoutB1;
-        __gm__ ElementB *ptrB2;
+        GM_ADDR ptrB2;
         LayoutB layoutB2;
-        __gm__ ElementScale *ptrScale1;
+        GM_ADDR ptrScale1;
         LayoutScale layoutScale1;
-        __gm__ ElementScale *ptrScale2;
+        GM_ADDR ptrScale2;
         LayoutScale layoutScale2;
         __gm__ ElementD2 *ptrOutput;
         LayoutD1 layoutD1;
         LayoutD2 layoutD2;
         GM_ADDR ptrWorkspace;
         int32_t EP;
+        int32_t listLen;
         int32_t expertPerRank;
         uint32_t maxOutputSize;
         uint32_t rank;
@@ -121,7 +123,7 @@ public:
         CATLASS_HOST_DEVICE
         Params(
             GemmCoord problemShape_,
-            uint32_t EP_, uint32_t expertPerRank_, uint32_t maxOutputSize_,
+            uint32_t EP_, uint32_t listLen_, uint32_t expertPerRank_, uint32_t maxOutputSize_,
             uint32_t rank_, uint32_t rankSize_, int64_t topK_,
             uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_,
             GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_,
@@ -136,15 +138,15 @@ public:
             GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
             optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
         ) : problemShape(problemShape_),
-            EP(EP_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
+            EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
             rank(rank_), rankSize(rankSize_), topK(topK_),
             initRoutingQuantTilingKey(initRoutingQuantTilingKey_),
             epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_),
             ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_),
-            ptrB1(reinterpret_cast<__gm__ ElementB *>(ptrB1_)), layoutB1(layoutB1_),
-            ptrB2(reinterpret_cast<__gm__ ElementB *>(ptrB2_)), layoutB2(layoutB2_),
-            ptrScale1(reinterpret_cast<__gm__ ElementScale *>(ptrScale1_)), layoutScale1(layoutScale1_),
-            ptrScale2(reinterpret_cast<__gm__ ElementScale *>(ptrScale2_)), layoutScale2(layoutScale2_),
+            ptrB1(ptrB1_), layoutB1(layoutB1_),
+            ptrB2(ptrB2_), layoutB2(layoutB2_),
+            ptrScale1(ptrScale1_), layoutScale1(layoutScale1_),
+            ptrScale2(ptrScale2_), layoutScale2(layoutScale2_),
             ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_),
             expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
             moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
@@ -212,11 +214,9 @@ private:
         cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM));
 
         gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA));
-        gmS.SetGlobalBuffer(params.ptrScale1);
         gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC));
 
         gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken));
-        gmS2.SetGlobalBuffer(params.ptrScale2);
         gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2));
 
         gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale));
@@ -224,7 +224,7 @@ private:
 
         tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));
 
-        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank + 8, params.expertPerRank);
+        tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank);
     }
 
     template<typename T>
@@ -291,7 +291,7 @@ private:
         AscendC::DataCopyPad(
             tmpBuffer1,
             tokenPerExpert[rankId * expertPerRank],
-            {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank + 8) * sizeof(int32_t)), 0},
+            {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0},
             {}
         );
 
@@ -327,6 +327,18 @@ private:
         AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul
         int64_t preCurrentmSum = 0;
         int32_t syncLoopIdx = -1;
+
+        constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
+        __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
+        __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
+
+        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
+        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
+            weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
+            scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
+        }
+        AscendC::PipeBarrier<PIPE_ALL>();
+
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
             if (preCurrentmSum >= params.maxOutputSize) {
@@ -335,7 +347,13 @@ private:
                 currentM = params.maxOutputSize - preCurrentmSum;
             }
             AscendC::GlobalTensor<ElementB> gmB1;
-            gmB1.SetGlobalBuffer(params.ptrB1);
+            AscendC::GlobalTensor<ElementScale> gmS;
+            int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
+            gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
+            gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
+
+            AscendC::PipeBarrier<PIPE_ALL>();
+
             if (currentM <= L1TileShape::M) {
                 gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
             }
@@ -364,7 +382,7 @@ private:
             int64_t gmOffsetA = layoutA.GetOffset(offsetA);
             int64_t gmOffsetB = layoutB1.GetOffset(offsetB);
             int64_t gmOffsetC = layoutC.GetOffset(offsetC);
-            int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // One scale group per expert
+            int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * params.problemShape.n() : 0);
             if (currentM > 0) {
                 blockMmad(
                     gmA[gmGroupOffsetA + gmOffsetA], layoutA,
@@ -386,7 +404,9 @@ private:
 
             preCurrentmSum += currentM;
             gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
+            if (params.listLen == 1) {
                 gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
+            }
             gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
         }
@@ -420,6 +440,17 @@ private:
         if (params.epilogueGranularity < params.expertPerRank) {
             lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
         }
+
+        constexpr uint32_t MAX_EXPERTS_PER_RANK = 8;
+        __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
+        __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
+        int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
+        for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
+            weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
+            scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
+        }
+        AscendC::PipeBarrier<PIPE_ALL>();
+
         for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
             if (preCurrentmSum >= params.maxOutputSize) {
@@ -428,7 +459,12 @@ private:
                 currentM = params.maxOutputSize - preCurrentmSum;
             }
             AscendC::GlobalTensor<ElementB> gmB2;
-            gmB2.SetGlobalBuffer(params.ptrB2);
+            AscendC::GlobalTensor<ElementScale> gmS2;
+            AscendC::PipeBarrier<PIPE_ALL>();
+            int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
+            gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
+            gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
+
             if (currentM <= L1TileShape::M) {
                 gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
             }
@@ -465,7 +501,7 @@ private:
             int64_t gmOffsetA = layoutA.GetOffset(offsetA);
             int64_t gmOffsetB = layoutB2.GetOffset(offsetB);
             int64_t gmOffsetC = layoutC.GetOffset(offsetC);
-            int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // One scale group per expert
+            int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * n2 : 0); // One scale group per expert
             if (currentM > 0) {
                 blockMmad(
                     gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
@@ -478,7 +514,9 @@ private:
             }
             preCurrentmSum += currentM;
             gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k();
+            if (params.listLen == 1) {
                 gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n();
+            }
             gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n();
 
             startCoreIdx = (startCoreIdx + coreLoops) % coreNum;
@@ -491,12 +529,29 @@ private:
         blockMmad.Finalize(params.expertPerRank - 1, 3);
     }
 
+    CATLASS_DEVICE
+    void ResetTokenPerExpert(AscendC::GlobalTensor<int32_t> & tokenPerExpert, int32_t num)
+    {
+        if (coreIdx != coreNum - 1) {
+            return;
+        }
+        AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
+        AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(EVENT_ID0);
+        AscendC::LocalTensor<int32_t> tmp = resource.ubBuf.template GetBufferByByte<int32_t>(0);
+        AscendC::Duplicate(tmp, 0, num);
+        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
+        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
+        AscendC::DataCopy(tokenPerExpert, tmp, num);
+    }
+
     CATLASS_DEVICE
     void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const &params, int64_t localTokenPerExpertOffset){
-        uint64_t flag_offset = (shmem.SegmentSize() - MB_SIZE) / sizeof(int32_t);
-        __gm__ int32_t* sync_base = shmem.SyncBaseAddr();
-        int count = gm_load(sync_base) + 1;
-        if (coreIdx < params.EP && coreIdx != params.rank) {
+        AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
+        uint32_t numPerCore = params.EP * params.expertPerRank;
+        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
+            if (dstEpIdx == params.rank) {
+                continue;
+            }
             AscendC::GlobalTensor<int32_t> srcAddress;
             srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset));
             AscendC::GlobalTensor<int32_t> dstAddress;
@@ -509,27 +564,42 @@ private:
             using CopyUbToGm = Epilogue::Tile::CopyUb2Gm<ArchTag, TType>;
             CopyGmToUb copyGmToUb;
             CopyUbToGm copyUbToGm;
-            AscendC::LocalTensor<int32_t> tmpBuffer = resource.ubBuf.template GetBufferByByte<int32_t>(0);
-            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
-            uint32_t tmp = params.EP * params.expertPerRank;
-            copyGmToUb(tmpBuffer, srcAddress[0],
-                layout::RowMajor{ 1, tmp},
-                layout::RowMajor{1, tmp});
 
-            tmpBuffer.SetValue(params.EP * params.expertPerRank, count);
-            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
-            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
+            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
+            copyGmToUb(tmpBuffer, srcAddress[0],
+                layout::RowMajor{ 1, numPerCore},
+                layout::RowMajor{1, numPerCore});
+
+            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
+            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
+            AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore);
+            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
+            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
             copyUbToGm(dstAddress[0], tmpBuffer,
-                layout::RowMajor{ 1, tmp + 1},
-                layout::RowMajor{1, tmp + 1});
+                layout::RowMajor{ 1, numPerCore},
+                layout::RowMajor{1, numPerCore});
             AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
             AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
+        }
-        __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(coreIdx, params.EP, 0);
-        gm_signal_wait_until_eq_for_barrier(sync_check, count);
+        for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
+            if (dstEpIdx == params.rank) {
+                continue;
+            }
+            int32_t intPer512 = CACHE_LINE / sizeof(int);
+            for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) {
+                __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx);
+                gm_signal_wait_until_ne(sync_check, 0);
+            }
+            AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore);
+            AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
+            AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(EVENT_ID0);
+            AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore);
+            AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
+            AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(EVENT_ID0);
+            AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore);
         }
         AscendC::SyncAll<true>();
-        gm_store(sync_base, count);
     }
 
@@ -569,7 +639,8 @@ private:
         uint32_t prevGroupSum1 = 0;
         uint32_t dequantSum = 0;
         int32_t syncLoopIdx = -1;
-        BlockEpilogue1 blockEpilogue(resource);
+        uint32_t n = params.problemShape.n();
+        BlockEpilogue1 blockEpilogue(resource, n);
         for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
             // The ith core reads data from the ith rank's peermem
             groupIdxDeq = groupIdx - 2;
@@ -668,7 +739,8 @@ private:
         typename BlockEpilogue2::Params epilogueParams{
             static_cast<int32_t>(params.EP),
             static_cast<int32_t>(params.expertPerRank),
-            reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace)
+            reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace),
+            static_cast<int32_t>(n2)
         };
         BlockEpilogue2 blockEpilogue(resource, epilogueParams);
         int32_t prevGroupSum2 = 0;
@@ -704,6 +776,7 @@ private:
         }
         blockEpilogue.Finalize();
         AscendC::SyncAll<true>();
+        ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank);
         shmem.CrossRankSync();
         MoeTokenUnpermuteTilingData tilingData;
         MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum);
@@ -794,10 +867,8 @@ private:
 
     AscendC::GlobalTensor<ElementA> gmA;
     AscendC::GlobalTensor<ElementC> gmC;
-    AscendC::GlobalTensor<ElementScale> gmS;
 
     AscendC::GlobalTensor<ElementD1> gmPermutedToken;
-    AscendC::GlobalTensor<ElementScale> gmS2;
     AscendC::GlobalTensor<ElementC> gmC2;
 
     AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale1;
@@ -30,6 +30,7 @@ struct DispatchFFNCombineInfo {
     uint32_t totalUbSize;
     uint32_t topK;
     uint32_t worldSize;
+    uint32_t listLen;
 };
 
 struct CoCTiling {
@@ -70,23 +70,24 @@ public:
        __gm__ int32_t *ptrTokenPerExpert{nullptr};
         int32_t EP;
         int32_t expertPerRank;
+        int32_t n2;
 
         CATLASS_DEVICE
         Params() {};
 
         CATLASS_DEVICE
-        Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {}
+        Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_, int32_t n2_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_), n2(n2_) {}
     };
 
     CATLASS_DEVICE
     BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
     {
-        size_t ubOffset = 4096;
+        size_t ubOffset = 0;
         int32_t eventVMTE2 = 0;
         int32_t eventMTE2V = 0;
         int32_t eventMTE3V = 0;
         int32_t eventVMTE3 = 0;
-        constexpr int32_t blockN = 12000;
+        int32_t blockN = params.n2;
         for (uint32_t i = 0; i < UB_STAGES; ++i) {
             ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
             ubOffset += blockN * sizeof(ElementC);
@@ -84,16 +84,16 @@ public:
     };
 
     CATLASS_DEVICE
-    BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
+    BlockEpilogue(Arch::Resource<ArchTag> const &resource, int32_t n, Params const &params = Params{}) : params(params)
     {
         size_t ubOffset = 0;
         int32_t eventVMTE2 = 0;
         int32_t eventMTE2V = 0;
         int32_t eventMTE3V = 0;
         int32_t eventVMTE3 = 0;
-        constexpr uint32_t blockN = 4096;
-        constexpr uint32_t ChunkTileLen = blockN / 2;
-        constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2;
+        uint32_t blockN = n;
+        uint32_t ChunkTileLen = blockN / 2;
+        uint32_t HalfChunkTileLen = ChunkTileLen / 2;
 
         for (uint32_t i = 0; i < UB_STAGES; ++i) {
             ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
@@ -3,4 +3,6 @@
 #define CONST_ARGS_HPP
 constexpr static uint64_t MB_SIZE = 1024 * 1024UL;
 constexpr static int32_t NUMS_PER_FLAG = 16;
+constexpr static int32_t CACHE_LINE = 512;
+constexpr static int32_t RESET_VAL = 0xffff;
 #endif
@@ -0,0 +1,16 @@
+#ifndef GET_TENSOR_ADDR_HPP
+#define GET_TENSOR_ADDR_HPP
+#include "kernel_operator.h"
+
+#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
+
+template <typename T>
+FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) {
+    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
+    uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address.
+    // Shifting right by 3 bits means dividing by sizeof(uint64_t).
+    __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
+    return reinterpret_cast<__gm__ T*>(*(retPtr + index));
+}
+
+#endif // GET_TENSOR_ADDR_HPP
@@ -53,13 +53,30 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
 }
 
 
+FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) {
+    do {
+        AscendC::LocalTensor<int32_t> ub;
+        ub.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN);
+        ub.address_.bufferAddr = 0;
+        AscendC::GlobalTensor<int32_t> sig;
+        sig.SetGlobalBuffer(sig_addr);
+        AscendC::DataCopy(ub, sig, 8);
+        AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
+        AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(EVENT_ID0);
+        if (ub(0) != cmp_val) {
+            return;
+        }
+    } while (true);
+    return;
+}
+
+
 constexpr int32_t MAX_RANK_SIZE = 32;
 class HcclShmem {
 public:
 #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
     __gm__ HcclOpResParamCustom *WinContext_{nullptr};
     Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
-    GM_ADDR m_ptrArray[MAX_RANK_SIZE];
     size_t m_segmentSize;
     int32_t m_rank;
     int32_t m_rankSize;
@@ -73,11 +90,6 @@ public:
         m_rankSize = WinContext_->rankSize;
         m_segmentSize = WinContext_->winSize;
 
-        for (int i = 0; i < m_rankSize; i++) {
-            m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn :
-                ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn);
-        }
-
     }
 
     FORCE_INLINE_AICORE
@@ -94,7 +106,7 @@ public:
     FORCE_INLINE_AICORE
     GM_ADDR operator() () const { // No argument: return local peermem
 #ifdef HCCL_COMM
-        return m_ptrArray[m_rank];
+        return (GM_ADDR)(WinContext_->localWindowsIn);
 #else
         return reinterpret_cast<GM_ADDR>(shmemi_get_state()->heap_base);
 #endif
@@ -103,7 +115,8 @@ public:
     FORCE_INLINE_AICORE
     GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
 #ifdef HCCL_COMM
-        return m_ptrArray[index];
+        return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn :
+            ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn);
 #else
         return reinterpret_cast<GM_ADDR>(shmem_ptr(shmemi_get_state()->heap_base, index));
 #endif
@@ -120,7 +133,8 @@ public:
         if (rankId < 0 || rankId >= m_rankSize) {
             return nullptr;
         }
-        return m_ptrArray[rankId] + offset;
+        return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn :
+            ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset;
 #else
         return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId);
 #endif
@@ -727,11 +727,11 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
 
 at::Tensor& dispatch_ffn_combine(
     const at::Tensor& x,
-    const at::Tensor& weight1,
-    const at::Tensor& weight2,
+    const at::TensorList& weight1,
+    const at::TensorList& weight2,
     const at::Tensor& expert_idx,
-    const at::Tensor& scale1,
-    const at::Tensor& scale2,
+    const at::TensorList& scale1,
+    const at::TensorList& scale2,
     const at::Tensor& probs,
     c10::string_view group,
     int64_t max_output_size,
@@ -1383,8 +1383,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
     ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
 
     ops.def(
-        "dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx,"
-        " Tensor scale1, Tensor scale2, Tensor probs, str group,"
+        "dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx,"
+        " Tensor[] scale1, Tensor[] scale2, Tensor probs, str group,"
         " int max_output_size, Tensor! out) -> Tensor"
     );
     ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
@@ -196,11 +196,11 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
 
 at::Tensor& dispatch_ffn_combine_meta(
     const at::Tensor& x,
-    const at::Tensor& weight1,
-    const at::Tensor& weight2,
+    const at::TensorList& weight1,
+    const at::TensorList& weight2,
     const at::Tensor& expert_idx,
-    const at::Tensor& scale1,
-    const at::Tensor& scale2,
+    const at::TensorList& scale1,
+    const at::TensorList& scale2,
     const at::Tensor& probs,
     c10::string_view group,
     int64_t max_output_size,
@@ -87,13 +87,13 @@ class TestDisptachFFNCombine:
         hcomm_info = hcomm_info_dist["default_pg_info"]
         self.hcomm_info = hcomm_info
 
-    def run_npu_out(self) -> bool:
+    def run_tensor_list(self) -> bool:
         torch_npu.npu.set_device(self.rank)
-        m = 2 # token-num 32
-        k = 4 # hidden_size 7168
-        n = 4 # mid-hidden-size 4096
-        topk = 2
-        e = 2 # expert-num-per-rank 16
+        m = 64
+        k = 1024
+        n = 1024
+        topk = 8
+        e = 8
         k2 = n // 2
         n2 = k
 
@@ -112,15 +112,79 @@ class TestDisptachFFNCombine:
         scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
         scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
         probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
 
+        weight1_nz_npu = []
+        weight2_nz_npu = []
+        scale1_npu = []
+        scale2_npu = []
+        for i in range(e):
+            weight1_nz_npu.append(
+                torch_npu.npu_format_cast(weight1[i].npu(), 29))
+            scale1_npu.append(scale1[i].npu())
+            weight2_nz_npu.append(
+                torch_npu.npu_format_cast(weight2[i].npu(), 29))
+            scale2_npu.append(scale2[i].npu())
+
         out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
 
         torch.ops._C_ascend.dispatch_ffn_combine(
             x=x,
-            weight1=weight1,
-            weight2=weight2,
+            weight1=weight1_nz_npu,
+            weight2=weight2_nz_npu,
             expert_idx=expert_idx,
-            scale1=scale1,
-            scale2=scale2,
+            scale1=scale1_npu,
+            scale2=scale2_npu,
+            probs=probs,
+            group=self.hcomm_info,
+            max_output_size=512,
+            out=out,
+        )
+        return True
+
+    def run_normal(self) -> bool:
+        torch_npu.npu.set_device(self.rank)
+        m = 64
+        k = 1024
+        n = 1024
+        topk = 8
+        e = 8
+        k2 = n // 2
+        n2 = k
+
+        torch_npu.npu.config.allow_internal_format = True
+        x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+        weight1 = self.generate_random_tensor((e, k, n),
+                                              dtype=torch.int8).npu()
+        weight1 = torch_npu.npu_format_cast(weight1, 29)
+        weight2 = self.generate_random_tensor((e, k2, n2),
+                                              dtype=torch.int8).npu()
+        weight2 = torch_npu.npu_format_cast(weight2, 29)
+
+        expert_idx = torch.randint(0,
+                                   self.world_size * e, (m, topk),
+                                   dtype=torch.int32).npu()
+        scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
+        scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
+        probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
+
+        weight1_nz_npu = []
+        weight2_nz_npu = []
+        scale1_npu = []
+        scale2_npu = []
+        weight1_nz_npu.append(torch_npu.npu_format_cast(weight1.npu(), 29))
+        scale1_npu.append(scale1.npu())
+        weight2_nz_npu.append(torch_npu.npu_format_cast(weight2.npu(), 29))
+        scale2_npu.append(scale2.npu())
+
+        out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
+
+        torch.ops._C_ascend.dispatch_ffn_combine(
+            x=x,
+            weight1=weight1_nz_npu,
+            weight2=weight2_nz_npu,
+            expert_idx=expert_idx,
+            scale1=scale1_npu,
+            scale2=scale2_npu,
             probs=probs,
             group=self.hcomm_info,
             max_output_size=512,
@@ -142,8 +206,10 @@ class TestDisptachFFNCombine:
 def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
     op = TestDisptachFFNCombine(rank, world_size, port)
     op.generate_hcom()
-    out = op.run_npu_out()
-    q.put(out)
+    out1 = op.run_tensor_list()
+    q.put(out1)
+    out2 = op.run_normal()
+    q.put(out2)
 
 
 @torch.inference_mode()
@@ -306,11 +306,11 @@ class FusedMC2CommImpl(MoECommMethod):
         out = torch.empty_like(hidden_states)
         torch.ops._C_ascend.dispatch_ffn_combine(  # type: ignore
             x=hidden_states,
-            weight1=w1[0],
-            weight2=w2[0],
+            weight1=w1,
+            weight2=w2,
             expert_idx=topk_ids,
-            scale1=w1_scale[0],
-            scale2=w2_scale[0],
+            scale1=w1_scale,
+            scale2=w2_scale,
             probs=topk_weights.to(torch.float32),
             group=self.token_dispatcher.moe_all_to_all_group_name,
             max_output_size=65536,