Add Streaming zipformer (#50)

This commit is contained in:
Fangjun Kuang
2023-02-21 20:00:03 +08:00
committed by GitHub
parent 8d1be0945e
commit 3ea6aa949d
20 changed files with 1576 additions and 99 deletions

View File

@@ -46,16 +46,74 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
Ort::Value Clone(Ort::Value *v) {
Ort::Value Clone(const Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector<int64_t> shape = type_and_shape.GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(),
shape.data(), shape.size());
switch (type_and_shape.GetElementType()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return Ort::Value::CreateTensor(
memory_info,
const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));
exit(-1);
// unreachable code
return Ort::Value{nullptr};
}
}
void Print1D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
fprintf(stderr, "%.3f ", d[i]);
}
fprintf(stderr, "\n");
}
void Print2D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
fprintf(stderr, "%.3f ", *d);
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");
}
void Print3D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
for (int32_t p = 0; p != static_cast<int32_t>(shape[0]); ++p) {
fprintf(stderr, "---plane %d---\n", p);
for (int32_t r = 0; r != static_cast<int32_t>(shape[1]); ++r) {
for (int32_t c = 0; c != static_cast<int32_t>(shape[2]); ++c, ++d) {
fprintf(stderr, "%.3f ", *d);
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "\n");
}
} // namespace sherpa_onnx