Use deep copy in Clone() (#66)
This commit is contained in:
@@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun initModel() {
|
private fun initModel() {
|
||||||
|
// Please change getModelConfig() to add new models
|
||||||
|
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
|
// for a list of available models
|
||||||
|
val type = 0
|
||||||
|
println("Select model type ${type}")
|
||||||
val config = OnlineRecognizerConfig(
|
val config = OnlineRecognizerConfig(
|
||||||
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
|
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
|
||||||
modelConfig = getModelConfig(type = 1)!!,
|
modelConfig = getModelConfig(type = type)!!,
|
||||||
endpointConfig = getEndpointConfig(),
|
endpointConfig = getEndpointConfig(),
|
||||||
enableEndpoint = true
|
enableEndpoint = true
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class ViewController: UIViewController {
|
|||||||
super.viewDidLoad()
|
super.viewDidLoad()
|
||||||
// Do any additional setup after loading the view.
|
// Do any additional setup after loading the view.
|
||||||
|
|
||||||
resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!"
|
resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
|
||||||
recordBtn.setTitle("Start", for: .normal)
|
recordBtn.setTitle("Start", for: .normal)
|
||||||
initRecognizer()
|
initRecognizer()
|
||||||
initRecorder()
|
initRecorder()
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ template <typename T /*=float*/>
|
|||||||
Ort::Value Cat(OrtAllocator *allocator,
|
Ort::Value Cat(OrtAllocator *allocator,
|
||||||
const std::vector<const Ort::Value *> &values, int32_t dim) {
|
const std::vector<const Ort::Value *> &values, int32_t dim) {
|
||||||
if (values.size() == 1u) {
|
if (values.size() == 1u) {
|
||||||
return Clone(values[0]);
|
return Clone(allocator, values[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> v0_shape =
|
std::vector<int64_t> v0_shape =
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
for (int32_t t = 0; t != num_frames; ++t) {
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
||||||
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
|
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
|
||||||
Ort::Value logit =
|
Ort::Value logit = model_->RunJoiner(
|
||||||
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
const float *p_logit = logit.GetTensorData<float>();
|
||||||
|
|
||||||
bool emitted = false;
|
bool emitted = false;
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value Clone(const Ort::Value *v) {
|
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
|
||||||
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
||||||
std::vector<int64_t> shape = type_and_shape.GetShape();
|
std::vector<int64_t> shape = type_and_shape.GetShape();
|
||||||
|
|
||||||
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
|
|||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
switch (type_and_shape.GetElementType()) {
|
switch (type_and_shape.GetElementType()) {
|
||||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
|
||||||
return Ort::Value::CreateTensor(
|
Ort::Value ans = Ort::Value::CreateTensor<int32_t>(
|
||||||
memory_info,
|
allocator, shape.data(), shape.size());
|
||||||
const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),
|
const int32_t *start = v->GetTensorData<int32_t>();
|
||||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
const int32_t *end = start + type_and_shape.GetElementCount();
|
||||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
int32_t *dst = ans.GetTensorMutableData<int32_t>();
|
||||||
return Ort::Value::CreateTensor(
|
std::copy(start, end, dst);
|
||||||
memory_info,
|
return ans;
|
||||||
const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),
|
}
|
||||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
|
||||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
Ort::Value ans = Ort::Value::CreateTensor<int64_t>(
|
||||||
return Ort::Value::CreateTensor(
|
allocator, shape.data(), shape.size());
|
||||||
memory_info,
|
const int64_t *start = v->GetTensorData<int64_t>();
|
||||||
const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),
|
const int64_t *end = start + type_and_shape.GetElementCount();
|
||||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
int64_t *dst = ans.GetTensorMutableData<int64_t>();
|
||||||
|
std::copy(start, end, dst);
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
|
||||||
|
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, shape.data(),
|
||||||
|
shape.size());
|
||||||
|
const float *start = v->GetTensorData<float>();
|
||||||
|
const float *end = start + type_and_shape.GetElementCount();
|
||||||
|
float *dst = ans.GetTensorMutableData<float>();
|
||||||
|
std::copy(start, end, dst);
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "Unsupported type: %d\n",
|
fprintf(stderr, "Unsupported type: %d\n",
|
||||||
static_cast<int32_t>(type_and_shape.GetElementType()));
|
static_cast<int32_t>(type_and_shape.GetElementType()));
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
|||||||
void PrintModelMetadata(std::ostream &os,
|
void PrintModelMetadata(std::ostream &os,
|
||||||
const Ort::ModelMetadata &meta_data); // NOLINT
|
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||||
|
|
||||||
// Return a shallow copy of v
|
// Return a deep copy of v
|
||||||
Ort::Value Clone(const Ort::Value *v);
|
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
|
||||||
|
|
||||||
// Print a 1-D tensor to stderr
|
// Print a 1-D tensor to stderr
|
||||||
void Print1D(Ort::Value *v);
|
void Print1D(Ort::Value *v);
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
|
|||||||
int32_t n = static_cast<int32_t>(shape[dim]);
|
int32_t n = static_cast<int32_t>(shape[dim]);
|
||||||
if (n == 1) {
|
if (n == 1) {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.push_back(Clone(value));
|
ans.push_back(Clone(allocator, value));
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user