diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 2ca03e80..ee2e3e67 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() { } private fun initModel() { + // Please change getModelConfig() to add new models + // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + // for a list of available models + val type = 0 + println("Select model type ${type}") val config = OnlineRecognizerConfig( featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), - modelConfig = getModelConfig(type = 1)!!, + modelConfig = getModelConfig(type = type)!!, endpointConfig = getEndpointConfig(), enableEndpoint = true ) diff --git a/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift b/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift index 409689d2..7376a8e4 100644 --- a/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift +++ b/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift @@ -63,7 +63,7 @@ class ViewController: UIViewController { super.viewDidLoad() // Do any additional setup after loading the view. - resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!" + resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!" recordBtn.setTitle("Start", for: .normal) initRecognizer() initRecorder() diff --git a/sherpa-onnx/csrc/cat.cc b/sherpa-onnx/csrc/cat.cc index f1938193..8bfe1a33 100644 --- a/sherpa-onnx/csrc/cat.cc +++ b/sherpa-onnx/csrc/cat.cc @@ -37,7 +37,7 @@ template Ort::Value Cat(OrtAllocator *allocator, const std::vector &values, int32_t dim) { if (values.size() == 1u) { - return Clone(values[0]); + return Clone(allocator, values[0]); } std::vector v0_shape = diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index c4b9ae15..46dbcbb3 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( for (int32_t t = 0; t != num_frames; ++t) { Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); - Ort::Value logit = - model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); + Ort::Value logit = model_->RunJoiner( + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); const float *p_logit = logit.GetTensorData(); bool emitted = false; diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 47e34359..efef86b1 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { } } -Ort::Value Clone(const Ort::Value *v) { +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { auto type_and_shape = v->GetTensorTypeAndShapeInfo(); std::vector shape = type_and_shape.GetShape(); @@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) { Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); switch (type_and_shape.GetElementType()) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return Ort::Value::CreateTensor( - memory_info, - const_cast(v)->GetTensorMutableData(), - type_and_shape.GetElementCount(), shape.data(), shape.size()); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return Ort::Value::CreateTensor( - memory_info, - const_cast(v)->GetTensorMutableData(), - type_and_shape.GetElementCount(), shape.data(), shape.size()); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return Ort::Value::CreateTensor( - memory_info, - const_cast(v)->GetTensorMutableData(), - type_and_shape.GetElementCount(), shape.data(), shape.size()); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + Ort::Value ans = Ort::Value::CreateTensor( + allocator, shape.data(), shape.size()); + const int32_t *start = v->GetTensorData(); + const int32_t *end = start + type_and_shape.GetElementCount(); + int32_t *dst = ans.GetTensorMutableData(); + std::copy(start, end, dst); + return ans; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + Ort::Value ans = Ort::Value::CreateTensor( + allocator, shape.data(), shape.size()); + const int64_t *start = v->GetTensorData(); + const int64_t *end = start + type_and_shape.GetElementCount(); + int64_t *dst = ans.GetTensorMutableData(); + std::copy(start, end, dst); + return ans; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + Ort::Value ans = Ort::Value::CreateTensor(allocator, shape.data(), + shape.size()); + const float *start = v->GetTensorData(); + const float *end = start + type_and_shape.GetElementCount(); + float *dst = ans.GetTensorMutableData(); + std::copy(start, end, dst); + return ans; + } default: fprintf(stderr, "Unsupported type: %d\n", static_cast(type_and_shape.GetElementType())); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index a00414d0..1ac846f3 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names, void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT -// Return a shallow copy of v -Ort::Value Clone(const Ort::Value *v); +// Return a deep copy of v +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); // Print a 1-D tensor to stderr void Print1D(Ort::Value *v); diff --git a/sherpa-onnx/csrc/unbind.cc b/sherpa-onnx/csrc/unbind.cc index ec8c96ee..a7013698 100644 --- a/sherpa-onnx/csrc/unbind.cc +++ b/sherpa-onnx/csrc/unbind.cc @@ -26,7 +26,7 @@ std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value, int32_t n = static_cast(shape[dim]); if (n == 1) { std::vector ans; - ans.push_back(Clone(value)); + ans.push_back(Clone(allocator, value)); return ans; }