Add two-pass speech recognition Android/iOS demo (#304)

This commit is contained in:
Fangjun Kuang
2023-09-12 15:40:16 +08:00
committed by GitHub
parent 8982984ea2
commit debab7c091
97 changed files with 3546 additions and 57 deletions

View File

@@ -100,4 +100,42 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
return nullptr;
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}
{
auto buffer = ReadFile(mgr, filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}
return nullptr;
}
#endif
} // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <string>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -16,9 +21,15 @@ namespace sherpa_onnx {
class OfflineCtcModel {
public:
virtual ~OfflineCtcModel() = default;
static std::unique_ptr<OfflineCtcModel> Create(
const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineCtcModel> Create(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.

View File

@@ -16,6 +16,13 @@ std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(config);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineLM> OfflineLM::Create(AAssetManager *mgr,
const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(mgr, config);
}
#endif
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
// compute the max token seq so that we know how much space to allocate

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
@@ -20,6 +25,11 @@ class OfflineLM {
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineLM> Create(AAssetManager *mgr,
const OfflineLMConfig &config);
#endif
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.

View File

@@ -19,9 +19,21 @@ class OfflineNemoEncDecCtcModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<int64_t> shape =
@@ -57,10 +69,8 @@ class OfflineNemoEncDecCtcModel::Impl {
std::string FeatureNormalizationMethod() const { return normalize_type_; }
private:
void Init() {
auto buf = ReadFile(config_.nemo_ctc.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -104,6 +114,12 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -23,6 +28,12 @@ namespace sherpa_onnx {
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
public:
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineNemoEncDecCtcModel() override;
/** Run the forward method of the model.

View File

@@ -21,9 +21,21 @@ class OfflineParaformerModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.paraformer.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.paraformer.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
@@ -49,10 +61,8 @@ class OfflineParaformerModel::Impl {
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init() {
auto buf = ReadFile(config_.paraformer.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -101,6 +111,12 @@ class OfflineParaformerModel::Impl {
OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineParaformerModel::~OfflineParaformerModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -16,6 +21,11 @@ namespace sherpa_onnx {
class OfflineParaformerModel {
public:
explicit OfflineParaformerModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineParaformerModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineParaformerModel();
/** Run the forward method of the model.

View File

@@ -10,6 +10,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-model.h"
@@ -46,10 +51,24 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(OfflineCtcModel::Create(config_.model_config)) {
Init();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerCtcImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
Init();
}
#endif
void Init() {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
SHERPA_ONNX_LOGE(
@@ -69,7 +88,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}

View File

@@ -132,4 +132,121 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
exit(-1);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
Ort::SessionOptions sess_opts;
std::string model_filename;
if (!config.model_config.transducer.encoder_filename.empty()) {
model_filename = config.model_config.transducer.encoder_filename;
} else if (!config.model_config.paraformer.model.empty()) {
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
SHERPA_ONNX_LOGE("Please provide a model");
exit(-1);
}
auto buf = ReadFile(mgr, model_filename);
auto encoder_sess =
std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
"\n"
"(0) Transducer models from icefall"
"\n "
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"pruned_transducer_stateless7/export-onnx.py#L303"
"\n"
"(1) Nemo CTC models\n "
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"\n"
"(2) Paraformer"
"\n "
"https://huggingface.co/csukuangfj/"
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n "
"(3) Whisper"
"\n "
"(4) Tdnn models of the yesno recipe from icefall"
"\n "
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
"\n"
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
}
if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}
SHERPA_ONNX_LOGE(
"\nUnsupported model_type: %s\n"
"We support only the following model types at present: \n"
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n",
model_type.c_str());
exit(-1);
}
#endif
} // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -19,6 +24,11 @@ class OfflineRecognizerImpl {
static std::unique_ptr<OfflineRecognizerImpl> Create(
const OfflineRecognizerConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineRecognizerImpl> Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(

View File

@@ -11,6 +11,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
@@ -100,6 +105,28 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
#if __ANDROID_API__ >= 9
OfflineRecognizerParaformerImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineParaformerModel>(mgr,
config.model_config)) {
if (config.decoding_method == "greedy_search") {
int32_t eos_id = symbol_table_["</s>"];
decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#endif
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}

View File

@@ -10,6 +10,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -73,6 +78,32 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
#if __ANDROID_API__ >= 9
explicit OfflineRecognizerTransducerImpl(
AAssetManager *mgr, const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(mgr,
config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
exit(-1);
}
}
#endif
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default

View File

@@ -12,6 +12,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -253,16 +258,32 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
Init();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
Init();
}
#endif
void Init() {
// tokens.txt from whisper is base64 encoded, so we need to decode it
symbol_table_.ApplyBase64Decode();
if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
config_.model_config.whisper, model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}

View File

@@ -58,6 +58,12 @@ std::string OfflineRecognizerConfig::ToString() const {
return os.str();
}
#if __ANDROID_API__ >= 9
OfflineRecognizer::OfflineRecognizer(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(mgr, config)) {}
#endif
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(config)) {}

View File

@@ -9,6 +9,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -55,6 +60,10 @@ class OfflineRecognizer {
public:
~OfflineRecognizer();
#if __ANDROID_API__ >= 9
OfflineRecognizer(AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
/// Create a stream for decoding.

View File

@@ -11,8 +11,8 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@@ -23,9 +23,21 @@ class OfflineRnnLM::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{GetSessionOptions(config)},
allocator_{} {
Init(config);
auto buf = ReadFile(config_.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineLMConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{GetSessionOptions(config)},
allocator_{} {
auto buf = ReadFile(mgr, config_.model);
Init(buf.data(), buf.size());
}
#endif
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
@@ -37,10 +49,8 @@ class OfflineRnnLM::Impl {
}
private:
void Init(const OfflineLMConfig &config) {
auto buf = ReadFile(config_.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -66,6 +76,11 @@ class OfflineRnnLM::Impl {
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineRnnLM::OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineRnnLM::~OfflineRnnLM() = default;
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {

View File

@@ -7,6 +7,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-lm.h"
@@ -19,6 +24,10 @@ class OfflineRnnLM : public OfflineLM {
explicit OfflineRnnLM(const OfflineLMConfig &config);
#if __ANDROID_API__ >= 9
OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config);
#endif
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.

View File

@@ -19,9 +19,21 @@ class OfflineTdnnCtcModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.tdnn.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.tdnn.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
auto nnet_out =
sess_->Run({}, input_names_ptr_.data(), &features, 1,
@@ -48,10 +60,8 @@ class OfflineTdnnCtcModel::Impl {
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init() {
auto buf = ReadFile(config_.tdnn.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -90,6 +100,12 @@ class OfflineTdnnCtcModel::Impl {
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -22,6 +27,11 @@ namespace sherpa_onnx {
class OfflineTdnnCtcModel : public OfflineCtcModel {
public:
explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineTdnnCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineTdnnCtcModel() override;
/** Run the forward method of the model.

View File

@@ -38,6 +38,29 @@ class OfflineTransducerModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.transducer.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#endif
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
@@ -221,6 +244,12 @@ class OfflineTransducerModel::Impl {
OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTransducerModel::OfflineTransducerModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTransducerModel::~OfflineTransducerModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -19,6 +24,11 @@ struct OfflineTransducerDecoderResult;
class OfflineTransducerModel {
public:
explicit OfflineTransducerModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineTransducerModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineTransducerModel();
/** Run the encoder.

View File

@@ -35,6 +35,24 @@ class OfflineWhisperModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#endif
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), &features, 1,
@@ -226,6 +244,12 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineWhisperModel::~OfflineWhisperModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(

View File

@@ -11,6 +11,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -19,6 +24,11 @@ namespace sherpa_onnx {
class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineWhisperModel();
/** Run the encoder model.