diff --git a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp index 5e1c1144..c4c11a4e 100644 --- a/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp +++ b/csrc/transpose_kv_cache_by_block/op_host/transpose_kv_cache_by_block_def.cpp @@ -17,9 +17,9 @@ public: .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("blockIDs") .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT64}) - .FormatList({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}); + .DataType({ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("blockSize").Int(); this->Attr("headNum").Int(); this->Attr("headDim").Int();