diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
index 2ca03e80..ee2e3e67 100644
--- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
@@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() {
}
private fun initModel() {
+ // Please change getModelConfig() to add new models
+ // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
+ // for a list of available models
+ val type = 0
+ println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
- modelConfig = getModelConfig(type = 1)!!,
+ modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true
)
diff --git a/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift b/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
index 409689d2..7376a8e4 100644
--- a/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
+++ b/ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
@@ -63,7 +63,7 @@ class ViewController: UIViewController {
super.viewDidLoad()
// Do any additional setup after loading the view.
- resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!"
+ resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
recordBtn.setTitle("Start", for: .normal)
initRecognizer()
initRecorder()
diff --git a/sherpa-onnx/csrc/cat.cc b/sherpa-onnx/csrc/cat.cc
index f1938193..8bfe1a33 100644
--- a/sherpa-onnx/csrc/cat.cc
+++ b/sherpa-onnx/csrc/cat.cc
@@ -37,7 +37,7 @@ template
Ort::Value Cat(OrtAllocator *allocator,
const std::vector &values, int32_t dim) {
if (values.size() == 1u) {
- return Clone(values[0]);
+ return Clone(allocator, values[0]);
}
std::vector v0_shape =
diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
index c4b9ae15..46dbcbb3 100644
--- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
@@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
- Ort::Value logit =
- model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
+ Ort::Value logit = model_->RunJoiner(
+ std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
const float *p_logit = logit.GetTensorData();
bool emitted = false;
diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc
index 47e34359..efef86b1 100644
--- a/sherpa-onnx/csrc/onnx-utils.cc
+++ b/sherpa-onnx/csrc/onnx-utils.cc
@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
}
}
-Ort::Value Clone(const Ort::Value *v) {
+Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector shape = type_and_shape.GetShape();
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
switch (type_and_shape.GetElementType()) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
- return Ort::Value::CreateTensor(
- memory_info,
- const_cast(v)->GetTensorMutableData(),
- type_and_shape.GetElementCount(), shape.data(), shape.size());
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
- return Ort::Value::CreateTensor(
- memory_info,
- const_cast(v)->GetTensorMutableData(),
- type_and_shape.GetElementCount(), shape.data(), shape.size());
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
- return Ort::Value::CreateTensor(
- memory_info,
- const_cast(v)->GetTensorMutableData(),
- type_and_shape.GetElementCount(), shape.data(), shape.size());
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+ Ort::Value ans = Ort::Value::CreateTensor(
+ allocator, shape.data(), shape.size());
+ const int32_t *start = v->GetTensorData();
+ const int32_t *end = start + type_and_shape.GetElementCount();
+ int32_t *dst = ans.GetTensorMutableData();
+ std::copy(start, end, dst);
+ return ans;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+ Ort::Value ans = Ort::Value::CreateTensor(
+ allocator, shape.data(), shape.size());
+ const int64_t *start = v->GetTensorData();
+ const int64_t *end = start + type_and_shape.GetElementCount();
+ int64_t *dst = ans.GetTensorMutableData();
+ std::copy(start, end, dst);
+ return ans;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
+ Ort::Value ans = Ort::Value::CreateTensor(allocator, shape.data(),
+ shape.size());
+ const float *start = v->GetTensorData();
+ const float *end = start + type_and_shape.GetElementCount();
+ float *dst = ans.GetTensorMutableData();
+ std::copy(start, end, dst);
+ return ans;
+ }
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast(type_and_shape.GetElementType()));
diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h
index a00414d0..1ac846f3 100644
--- a/sherpa-onnx/csrc/onnx-utils.h
+++ b/sherpa-onnx/csrc/onnx-utils.h
@@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names,
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
-// Return a shallow copy of v
-Ort::Value Clone(const Ort::Value *v);
+// Return a deep copy of v
+Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
// Print a 1-D tensor to stderr
void Print1D(Ort::Value *v);
diff --git a/sherpa-onnx/csrc/unbind.cc b/sherpa-onnx/csrc/unbind.cc
index ec8c96ee..a7013698 100644
--- a/sherpa-onnx/csrc/unbind.cc
+++ b/sherpa-onnx/csrc/unbind.cc
@@ -26,7 +26,7 @@ std::vector Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t n = static_cast(shape[dim]);
if (n == 1) {
std::vector ans;
- ans.push_back(Clone(value));
+ ans.push_back(Clone(allocator, value));
return ans;
}