[CustomOp] support TensorList for dispatchFFNCombine (#5665)

### What this PR does / why we need it?
Support `aclTensorList` inputs for `dispatch_ffn_combine` so that per-expert weight/scale lists can be passed, enabling EPLB (expert-parallel load balancing) adjustment.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Verified with single-operator tests.

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: lhchg <lhao_cheng@163.com>
Co-authored-by: lihaocheng <lihaosheng1@h-partners.com>
This commit is contained in:
lhchg
2026-01-09 15:56:29 +08:00
committed by GitHub
parent 3ce5a34468
commit dc99cfdc15
16 changed files with 293 additions and 105 deletions

View File

@@ -42,8 +42,8 @@ enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_END
};
extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
bool transB, bool weightNz,
@@ -55,8 +55,8 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,
aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
const aclTensor* out,

View File

@@ -39,8 +39,8 @@ extern "C" {
* @param [out] executor: op executor containing the operator compute flow.
* @return aclnnStatus: status code.
*/
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
const aclTensor* out,

View File

@@ -24,13 +24,13 @@ class DispatchFFNCombine : public OpDef {
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("w1")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.IgnoreContiguous();
this->Input("w2")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
@@ -41,12 +41,12 @@ class DispatchFFNCombine : public OpDef {
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale1")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale2")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

View File

@@ -91,27 +91,42 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
{
const char *nodeName = context->GetNodeName();
// OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling.");
const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX);
const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX);
auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0);
uint32_t M = aStorageShape->GetStorageShape().GetDim(0);
uint32_t K = aStorageShape->GetStorageShape().GetDim(1);
uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0);
uint32_t N = bStorageShape->GetStorageShape().GetDim(2);
uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1);
auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
uint32_t wTensorDims = wTensor->GetOriginShape().GetDimNum();
uint32_t N = wTensor->GetStorageShape().GetDim(wTensorDims - 1);
uint32_t topK = expertIdxTensor->GetStorageShape().GetDim(1);
uint32_t listLen = 0;
while (true) {
auto wTensorT = context->GetDynamicInputTensor(WEIGHT_INDEX, ++listLen);
if (wTensorT == nullptr) {break;}
}
uint32_t expertPerRank;
if (listLen == 1) {
expertPerRank = wTensor->GetStorageShape().GetDim(0);
} else {
expertPerRank = listLen;
}
info.M = M;
info.N = N;
info.K = K;
info.expertPerRank = expertPerRank;
info.topK = topK;
info.listLen = listLen;
OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen);
return ge::GRAPH_SUCCESS;
}