Use deep copy in Clone() (#66)

This commit is contained in:
Fangjun Kuang
2023-02-26 14:54:01 +08:00
committed by GitHub
parent 475caf22f9
commit 5a8c3a6d10
7 changed files with 41 additions and 24 deletions

View File

@@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() {
} }
private fun initModel() { private fun initModel() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 0
println("Select model type ${type}")
val config = OnlineRecognizerConfig( val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
modelConfig = getModelConfig(type = 1)!!, modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(), endpointConfig = getEndpointConfig(),
enableEndpoint = true enableEndpoint = true
) )

View File

@@ -63,7 +63,7 @@ class ViewController: UIViewController {
super.viewDidLoad() super.viewDidLoad()
// Do any additional setup after loading the view. // Do any additional setup after loading the view.
resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!" resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
recordBtn.setTitle("Start", for: .normal) recordBtn.setTitle("Start", for: .normal)
initRecognizer() initRecognizer()
initRecorder() initRecorder()

View File

@@ -37,7 +37,7 @@ template <typename T /*=float*/>
Ort::Value Cat(OrtAllocator *allocator, Ort::Value Cat(OrtAllocator *allocator,
const std::vector<const Ort::Value *> &values, int32_t dim) { const std::vector<const Ort::Value *> &values, int32_t dim) {
if (values.size() == 1u) { if (values.size() == 1u) {
return Clone(values[0]); return Clone(allocator, values[0]);
} }
std::vector<int64_t> v0_shape = std::vector<int64_t> v0_shape =

View File

@@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
for (int32_t t = 0; t != num_frames; ++t) { for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
Ort::Value logit = Ort::Value logit = model_->RunJoiner(
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
const float *p_logit = logit.GetTensorData<float>(); const float *p_logit = logit.GetTensorData<float>();
bool emitted = false; bool emitted = false;

View File

@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
} }
} }
Ort::Value Clone(const Ort::Value *v) { Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
auto type_and_shape = v->GetTensorTypeAndShapeInfo(); auto type_and_shape = v->GetTensorTypeAndShapeInfo();
std::vector<int64_t> shape = type_and_shape.GetShape(); std::vector<int64_t> shape = type_and_shape.GetShape();
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
switch (type_and_shape.GetElementType()) { switch (type_and_shape.GetElementType()) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
return Ort::Value::CreateTensor( Ort::Value ans = Ort::Value::CreateTensor<int32_t>(
memory_info, allocator, shape.data(), shape.size());
const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(), const int32_t *start = v->GetTensorData<int32_t>();
type_and_shape.GetElementCount(), shape.data(), shape.size()); const int32_t *end = start + type_and_shape.GetElementCount();
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: int32_t *dst = ans.GetTensorMutableData<int32_t>();
return Ort::Value::CreateTensor( std::copy(start, end, dst);
memory_info, return ans;
const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(), }
type_and_shape.GetElementCount(), shape.data(), shape.size()); case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: Ort::Value ans = Ort::Value::CreateTensor<int64_t>(
return Ort::Value::CreateTensor( allocator, shape.data(), shape.size());
memory_info, const int64_t *start = v->GetTensorData<int64_t>();
const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(), const int64_t *end = start + type_and_shape.GetElementCount();
type_and_shape.GetElementCount(), shape.data(), shape.size()); int64_t *dst = ans.GetTensorMutableData<int64_t>();
std::copy(start, end, dst);
return ans;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, shape.data(),
shape.size());
const float *start = v->GetTensorData<float>();
const float *end = start + type_and_shape.GetElementCount();
float *dst = ans.GetTensorMutableData<float>();
std::copy(start, end, dst);
return ans;
}
default: default:
fprintf(stderr, "Unsupported type: %d\n", fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType())); static_cast<int32_t>(type_and_shape.GetElementType()));

View File

@@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
void PrintModelMetadata(std::ostream &os, void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT const Ort::ModelMetadata &meta_data); // NOLINT
// Return a shallow copy of v // Return a deep copy of v
Ort::Value Clone(const Ort::Value *v); Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
// Print a 1-D tensor to stderr // Print a 1-D tensor to stderr
void Print1D(Ort::Value *v); void Print1D(Ort::Value *v);

View File

@@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
int32_t n = static_cast<int32_t>(shape[dim]); int32_t n = static_cast<int32_t>(shape[dim]);
if (n == 1) { if (n == 1) {
std::vector<Ort::Value> ans; std::vector<Ort::Value> ans;
ans.push_back(Clone(value)); ans.push_back(Clone(allocator, value));
return ans; return ans;
} }