Refactor OfflineRecognizer to dispatch on the model type

Move the transducer-specific logic of OfflineRecognizer::Impl into
OfflineRecognizerTransducerImpl and add the abstract OfflineRecognizerImpl,
whose Create() picks the implementation from the "model_type" metadata of the
encoder model. Also add SHERPA_ONNX_READ_META_DATA_STR and
SplitStringToFloats, and report metadata errors through SHERPA_ONNX_LOGE
(without a trailing "\n" in the format string, matching the other call sites).

diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 5eb1fb49..f25e8bf6 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -6,11 +6,12 @@ set(sources
   features.cc
   file-utils.cc
   hypothesis.cc
+  offline-recognizer-impl.cc
+  offline-recognizer.cc
   offline-stream.cc
   offline-transducer-greedy-search-decoder.cc
   offline-transducer-model-config.cc
   offline-transducer-model.cc
-  offline-recognizer.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h
index 5630add2..a44021e7 100644
--- a/sherpa-onnx/csrc/macros.h
+++ b/sherpa-onnx/csrc/macros.h
@@ -23,36 +23,55 @@
   } while (0)
 #endif
 
+// Read an integer
 #define SHERPA_ONNX_READ_META_DATA(dst, src_key)                        \
   do {                                                                  \
     auto value =                                                        \
         meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
     if (!value) {                                                       \
-      fprintf(stderr, "%s does not exist in the metadata\n", src_key);  \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);   \
       exit(-1);                                                         \
     }                                                                   \
                                                                         \
     dst = atoi(value.get());                                            \
     if (dst <= 0) {                                                     \
-      fprintf(stderr, "Invalid value %d for %s\n", dst, src_key);       \
+      SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key);        \
       exit(-1);                                                         \
     }                                                                   \
   } while (0)
 
-#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key)                      \
-  do {                                                                    \
-    auto value =                                                          \
-        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);   \
-    if (!value) {                                                         \
-      fprintf(stderr, "%s does not exist in the metadata\n", src_key);    \
-      exit(-1);                                                           \
-    }                                                                     \
-                                                                          \
-    bool ret = SplitStringToIntegers(value.get(), ",", true, &dst);       \
-    if (!ret) {                                                           \
-      fprintf(stderr, "Invalid value %s for %s\n", value.get(), src_key); \
-      exit(-1);                                                           \
-    }                                                                     \
+// Read a vector of integers
+#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key)                     \
+  do {                                                                   \
+    auto value =                                                         \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);  \
+    if (!value) {                                                        \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);    \
+      exit(-1);                                                          \
+    }                                                                    \
+                                                                         \
+    bool ret = SplitStringToIntegers(value.get(), ",", true, &dst);      \
+    if (!ret) {                                                          \
+      SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \
+      exit(-1);                                                          \
+    }                                                                    \
+  } while (0)
+
+// Read a string
+#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key)                    \
+  do {                                                                  \
+    auto value =                                                        \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
+    if (!value) {                                                       \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);   \
+      exit(-1);                                                         \
+    }                                                                   \
+                                                                        \
+    dst = value.get();                                                  \
+    if (dst.empty()) {                                                  \
+      SHERPA_ONNX_LOGE("Invalid value for %s", src_key);                \
+      exit(-1);                                                         \
+    }                                                                   \
   } while (0)
 
 #endif  // SHERPA_ONNX_CSRC_MACROS_H_
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc
new file mode 100644
index 00000000..cfcad26f
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc
@@ -0,0 +1,44 @@
+// sherpa-onnx/csrc/offline-recognizer-impl.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+
+#include <string>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
+    const OfflineRecognizerConfig &config) {
+  Ort::Env env;
+
+  Ort::SessionOptions sess_opts;
+  auto buf = ReadFile(config.model_config.encoder_filename);
+
+  // This session is only used to read the metadata of the encoder model.
+  auto encoder_sess =
+      std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
+
+  Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata();
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+
+  std::string model_type;
+  SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
+
+  if (model_type == "conformer") {
+    return std::make_unique<OfflineRecognizerTransducerImpl>(config);
+  }
+
+  SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str());
+
+  exit(-1);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h
new file mode 100644
index 00000000..065be58e
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-recognizer-impl.h
@@ -0,0 +1,34 @@
+// sherpa-onnx/csrc/offline-recognizer-impl.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
+
+#include <memory>
+
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-stream.h"
+
+namespace sherpa_onnx {
+
+// Abstract base class of the model-specific offline recognizers.
+// OfflineRecognizer owns an instance obtained from Create().
+class OfflineRecognizerImpl {
+ public:
+  // Reads "model_type" from the metadata of the encoder model and returns
+  // the matching implementation. Logs an error and calls exit(-1) when the
+  // entry is missing, empty, or names an unsupported model type.
+  static std::unique_ptr<OfflineRecognizerImpl> Create(
+      const OfflineRecognizerConfig &config);
+
+  virtual ~OfflineRecognizerImpl() = default;
+
+  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
+
+  virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
new file mode 100644
index 00000000..b9884e1b
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
@@ -0,0 +1,140 @@
+// sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
+//
+// Copyright (c) 2022 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-model.h"
+#include "sherpa-onnx/csrc/pad-sequence.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+
+namespace sherpa_onnx {
+
+// Converts a decoder result into an OfflineRecognitionResult: token ids are
+// mapped to symbols via sym_table and concatenated into the text, and frame
+// indices are scaled to seconds by frame_shift_ms / 1000 * subsampling_factor.
+static OfflineRecognitionResult Convert(
+    const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
+    int32_t frame_shift_ms, int32_t subsampling_factor) {
+  OfflineRecognitionResult r;
+  r.tokens.reserve(src.tokens.size());
+  r.timestamps.reserve(src.timestamps.size());
+
+  std::string text;
+  for (auto i : src.tokens) {
+    auto sym = sym_table[i];
+    text.append(sym);
+
+    r.tokens.push_back(std::move(sym));
+  }
+  r.text = std::move(text);
+
+  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
+  for (auto t : src.timestamps) {
+    float time = frame_shift_s * t;
+    r.timestamps.push_back(time);
+  }
+
+  return r;
+}
+
+// Offline recognizer for transducer models; currently only the
+// "greedy_search" decoding method is implemented.
+class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
+ public:
+  explicit OfflineRecognizerTransducerImpl(
+      const OfflineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
+    if (config_.decoding_method == "greedy_search") {
+      decoder_ =
+          std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
+    } else if (config_.decoding_method == "modified_beam_search") {
+      SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
+      exit(-1);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config_.decoding_method.c_str());
+      exit(-1);
+    }
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = ss[0]->FeatureDim();
+
+    std::vector<Ort::Value> features;
+
+    features.reserve(n);
+
+    std::vector<std::vector<float>> features_vec(n);
+    std::vector<int64_t> features_length_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      auto f = ss[i]->GetFrames();
+      int32_t num_frames = f.size() / feat_dim;
+
+      features_length_vec[i] = num_frames;
+      features_vec[i] = std::move(f);
+
+      std::array<int64_t, 2> shape = {num_frames, feat_dim};
+
+      Ort::Value x = Ort::Value::CreateTensor(
+          memory_info, features_vec[i].data(), features_vec[i].size(),
+          shape.data(), shape.size());
+      features.push_back(std::move(x));
+    }
+
+    std::vector<const Ort::Value *> features_pointer(n);
+    for (int32_t i = 0; i != n; ++i) {
+      features_pointer[i] = &features[i];
+    }
+
+    std::array<int64_t, 1> features_length_shape = {n};
+    Ort::Value x_length = Ort::Value::CreateTensor(
+        memory_info, features_length_vec.data(), n,
+        features_length_shape.data(), features_length_shape.size());
+
+    // The padding value -23.025850929940457 is log(1e-10).
+    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
+                               -23.025850929940457f);
+
+    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
+    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+
+    int32_t frame_shift_ms = 10;
+    for (int32_t i = 0; i != n; ++i) {
+      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
+                       model_->SubsamplingFactor());
+
+      ss[i]->SetResult(r);
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineTransducerModel> model_;
+  std::unique_ptr<OfflineTransducerDecoder> decoder_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc
index 30a4154a..873a5f3b 100644
--- a/sherpa-onnx/csrc/offline-recognizer.cc
+++ b/sherpa-onnx/csrc/offline-recognizer.cc
@@ -5,42 +5,11 @@
 #include "sherpa-onnx/csrc/offline-recognizer.h"
 
 #include <memory>
-#include <utility>
 
-#include "sherpa-onnx/csrc/macros.h"
-#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
-#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
-#include "sherpa-onnx/csrc/offline-transducer-model.h"
-#include "sherpa-onnx/csrc/pad-sequence.h"
-#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
 
 namespace sherpa_onnx {
 
-static OfflineRecognitionResult Convert(
-    const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
-    int32_t frame_shift_ms, int32_t subsampling_factor) {
-  OfflineRecognitionResult r;
-  r.tokens.reserve(src.tokens.size());
-  r.timestamps.reserve(src.timestamps.size());
-
-  std::string text;
-  for (auto i : src.tokens) {
-    auto sym = sym_table[i];
-    text.append(sym);
-
-    r.tokens.push_back(std::move(sym));
-  }
-  r.text = std::move(text);
-
-  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
-  for (auto t : src.timestamps) {
-    float time = frame_shift_s * t;
-    r.timestamps.push_back(time);
-  }
-
-  return r;
-}
-
 void OfflineRecognizerConfig::Register(ParseOptions *po) {
   feat_config.Register(po);
   model_config.Register(po);
@@ -65,90 +34,8 @@ std::string OfflineRecognizerConfig::ToString() const {
   return os.str();
 }
 
-class OfflineRecognizer::Impl {
- public:
-  explicit Impl(const OfflineRecognizerConfig &config)
-      : config_(config),
-        symbol_table_(config_.model_config.tokens),
-        model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
-    if (config_.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
-    } else if (config_.decoding_method == "modified_beam_search") {
-      SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
-      exit(-1);
-    } else {
-      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
-                       config_.decoding_method.c_str());
-      exit(-1);
-    }
-  }
-
-  std::unique_ptr<OfflineStream> CreateStream() const {
-    return std::make_unique<OfflineStream>(config_.feat_config);
-  }
-
-  void DecodeStreams(OfflineStream **ss, int32_t n) const {
-    auto memory_info =
-        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-    int32_t feat_dim = ss[0]->FeatureDim();
-
-    std::vector<Ort::Value> features;
-
-    features.reserve(n);
-
-    std::vector<std::vector<float>> features_vec(n);
-    std::vector<int64_t> features_length_vec(n);
-    for (int32_t i = 0; i != n; ++i) {
-      auto f = ss[i]->GetFrames();
-      int32_t num_frames = f.size() / feat_dim;
-
-      features_length_vec[i] = num_frames;
-      features_vec[i] = std::move(f);
-
-      std::array<int64_t, 2> shape = {num_frames, feat_dim};
-
-      Ort::Value x = Ort::Value::CreateTensor(
-          memory_info, features_vec[i].data(), features_vec[i].size(),
-          shape.data(), shape.size());
-      features.push_back(std::move(x));
-    }
-
-    std::vector<const Ort::Value *> features_pointer(n);
-    for (int32_t i = 0; i != n; ++i) {
-      features_pointer[i] = &features[i];
-    }
-
-    std::array<int64_t, 1> features_length_shape = {n};
-    Ort::Value x_length = Ort::Value::CreateTensor(
-        memory_info, features_length_vec.data(), n,
-        features_length_shape.data(), features_length_shape.size());
-
-    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
-                               -23.025850929940457f);
-
-    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
-    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
-
-    int32_t frame_shift_ms = 10;
-    for (int32_t i = 0; i != n; ++i) {
-      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
-                       model_->SubsamplingFactor());
-
-      ss[i]->SetResult(r);
-    }
-  }
-
- private:
-  OfflineRecognizerConfig config_;
-  SymbolTable symbol_table_;
-  std::unique_ptr<OfflineTransducerModel> model_;
-  std::unique_ptr<OfflineTransducerDecoder> decoder_;
-};
-
 OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
-    : impl_(std::make_unique<Impl>(config)) {}
+    : impl_(OfflineRecognizerImpl::Create(config)) {}
 
 OfflineRecognizer::~OfflineRecognizer() = default;
 
diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h
index 49423a03..e02e3095 100644
--- a/sherpa-onnx/csrc/offline-recognizer.h
+++ b/sherpa-onnx/csrc/offline-recognizer.h
@@ -52,6 +52,8 @@ struct OfflineRecognizerConfig {
   std::string ToString() const;
 };
 
+class OfflineRecognizerImpl;
+
 class OfflineRecognizer {
  public:
   ~OfflineRecognizer();
@@ -78,8 +80,7 @@ class OfflineRecognizer {
   void DecodeStreams(OfflineStream **ss, int32_t n) const;
 
  private:
-  class Impl;
-  std::unique_ptr<Impl> impl_;
+  std::unique_ptr<OfflineRecognizerImpl> impl_;
 };
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/text-utils.cc b/sherpa-onnx/csrc/text-utils.cc
index 6e98e8c1..44eb70c0 100644
--- a/sherpa-onnx/csrc/text-utils.cc
+++ b/sherpa-onnx/csrc/text-utils.cc
@@ -5,6 +5,8 @@
 
 #include "sherpa-onnx/csrc/text-utils.h"
 
+#include <assert.h>
+
 #include <string>
 #include <vector>
 
@@ -27,4 +29,31 @@ void SplitStringToVector(const std::string &full, const char *delim,
   }
 }
 
+template <typename F>
+bool SplitStringToFloats(const std::string &full, const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out) {
+  assert(out != nullptr);
+  if (*(full.c_str()) == '\0') {
+    out->clear();
+    return true;
+  }
+  std::vector<std::string> split;
+  SplitStringToVector(full, delim, omit_empty_strings, &split);
+  out->resize(split.size());
+  for (size_t i = 0; i < split.size(); ++i) {
+    // assume atof never fails
+    (*out)[i] = atof(split[i].c_str());
+  }
+  return true;
+}
+
+// Instantiate the template above for float and double.
+template bool SplitStringToFloats(const std::string &full, const char *delim,
+                                  bool omit_empty_strings,
+                                  std::vector<float> *out);
+template bool SplitStringToFloats(const std::string &full, const char *delim,
+                                  bool omit_empty_strings,
+                                  std::vector<double> *out);
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/text-utils.h b/sherpa-onnx/csrc/text-utils.h
index 1f7d544e..6c91b805 100644
--- a/sherpa-onnx/csrc/text-utils.h
+++ b/sherpa-onnx/csrc/text-utils.h
@@ -80,6 +80,14 @@ bool SplitStringToIntegers(const std::string &full, const char *delim,
   return true;
 }
 
+// Splits `full` at `delim` and converts every field with atof(). Conversion
+// errors are not detected, so the function always returns true.
+// This is defined for F = float and double.
+template <typename F>
+bool SplitStringToFloats(const std::string &full, const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out);
+
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_TEXT_UTILS_H_