Use deep copy in Clone() (#66)

This commit is contained in:
Fangjun Kuang
2023-02-26 14:54:01 +08:00
committed by GitHub
parent 475caf22f9
commit 5a8c3a6d10
7 changed files with 41 additions and 24 deletions

View File

@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
Ort::Value Clone(const Ort::Value *v) {
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector<int64_t> shape = type_and_shape.GetShape();
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
switch (type_and_shape.GetElementType()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
Ort::Value ans = Ort::Value::CreateTensor<int32_t>(
allocator, shape.data(), shape.size());
const int32_t *start = v->GetTensorData<int32_t>();
const int32_t *end = start + type_and_shape.GetElementCount();
int32_t *dst = ans.GetTensorMutableData<int32_t>();
std::copy(start, end, dst);
return ans;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
Ort::Value ans = Ort::Value::CreateTensor<int64_t>(
allocator, shape.data(), shape.size());
const int64_t *start = v->GetTensorData<int64_t>();
const int64_t *end = start + type_and_shape.GetElementCount();
int64_t *dst = ans.GetTensorMutableData<int64_t>();
std::copy(start, end, dst);
return ans;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, shape.data(),
shape.size());
const float *start = v->GetTensorData<float>();
const float *end = start + type_and_shape.GetElementCount();
float *dst = ans.GetTensorMutableData<float>();
std::copy(start, end, dst);
return ans;
}
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));