Support streaming conformer CTC models from wenet (#427)
This commit is contained in:
@@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value View(Ort::Value *v) {
|
||||
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
||||
std::vector<int64_t> shape = type_and_shape.GetShape();
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
switch (type_and_shape.GetElementType()) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<int32_t>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<int64_t>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<float>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type: %d\n",
|
||||
static_cast<int32_t>(type_and_shape.GetElementType()));
|
||||
exit(-1);
|
||||
// unreachable code
|
||||
return Ort::Value{nullptr};
|
||||
}
|
||||
}
|
||||
|
||||
void Print1D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
|
||||
Reference in New Issue
Block a user