[CustomOp] support TensorList for dispatchFFNCombine (#5665)
### What this PR does / why we need it?
Support `aclTensorList` inputs for `dispatch_ffn_combine`, so the operator can adapt to EPLB (expert-parallel load balancing).
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Single Operator Testing
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
This commit is contained in:
@@ -42,8 +42,8 @@ enum NnopbaseHcclServerType {
     NNOPBASE_HCCL_SERVER_TYPE_END
 };
 
-extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-                                                                const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+                                                                const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                                 const aclTensor* probs,
                                                                 const char* group, int64_t maxOutputSize,
                                                                 bool transB, bool weightNz,
@@ -55,8 +55,8 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,
 
 
-aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-                                                    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+                                                    const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                     const aclTensor* probs,
                                                     const char* group, int64_t maxOutputSize,
                                                     const aclTensor* out,
 
@@ -39,8 +39,8 @@ extern "C" {
 * @param [out] executor: op executor containing the operator compute flow.
 * @return aclnnStatus: status code.
 */
-__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
-    const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
+__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
+    const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
     const aclTensor* out,
 
@@ -24,13 +24,13 @@ class DispatchFFNCombine : public OpDef {
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("w1")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
             .IgnoreContiguous();
         this->Input("w2")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
            .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
@@ -41,12 +41,12 @@ class DispatchFFNCombine : public OpDef {
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("scale1")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Input("scale2")
-            .ParamType(REQUIRED)
+            .ParamType(DYNAMIC)
             .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
 
@@ -91,27 +91,42 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte
 static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
 {
     const char *nodeName = context->GetNodeName();
     // OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling.");
 
     const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
-    const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX);
-    const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX);
+    auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0);
     uint32_t M = aStorageShape->GetStorageShape().GetDim(0);
     uint32_t K = aStorageShape->GetStorageShape().GetDim(1);
-    uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0);
-    uint32_t N = bStorageShape->GetStorageShape().GetDim(2);
-    uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1);
+
+    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
+    uint32_t wTensorDims = wTensor->GetOriginShape().GetDimNum();
+    uint32_t N = wTensor->GetStorageShape().GetDim(wTensorDims - 1);
+
+    uint32_t topK = expertIdxTensor->GetStorageShape().GetDim(1);
+    uint32_t listLen = 0;
+    while (true) {
+        auto wTensorT = context->GetDynamicInputTensor(WEIGHT_INDEX, ++listLen);
+        if (wTensorT == nullptr) {break;}
+    }
+
+    uint32_t expertPerRank;
+    if (listLen == 1) {
+        expertPerRank = wTensor->GetStorageShape().GetDim(0);
+    } else {
+        expertPerRank = listLen;
+    }
 
     info.M = M;
     info.N = N;
     info.K = K;
     info.expertPerRank = expertPerRank;
     info.topK = topK;
+    info.listLen = listLen;
     OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
     OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
     OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
     OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
     OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
+    OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen);
 
     return ge::GRAPH_SUCCESS;
 }
 
Reference in New Issue
Block a user