Support paraformer. (#95)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print2D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
const T *d = v->GetTensorData<T>();
|
||||
|
||||
std::ostringstream os;
|
||||
for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
|
||||
for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
|
||||
fprintf(stderr, "%.3f ", *d);
|
||||
os << *d << " ";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
os << "\n";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
template void Print2D<int64_t>(Ort::Value *v);
|
||||
template void Print2D<float>(Ort::Value *v);
|
||||
|
||||
void Print3D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
|
||||
Reference in New Issue
Block a user