Add two-pass speech recognition Android/iOS demo (#304)
This commit is contained in:
@@ -100,4 +100,42 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
|
||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
|
||||
std::string filename;
|
||||
if (!config.nemo_ctc.model.empty()) {
|
||||
filename = config.nemo_ctc.model;
|
||||
} else if (!config.tdnn.model.empty()) {
|
||||
filename = config.tdnn.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(mgr, filename);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kEncDecCTCModelBPE:
|
||||
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kTdnn:
|
||||
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
@@ -16,9 +21,15 @@ namespace sherpa_onnx {
|
||||
class OfflineCtcModel {
|
||||
public:
|
||||
virtual ~OfflineCtcModel() = default;
|
||||
|
||||
static std::unique_ptr<OfflineCtcModel> Create(
|
||||
const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OfflineCtcModel> Create(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
|
||||
@@ -16,6 +16,13 @@ std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
|
||||
return std::make_unique<OfflineRnnLM>(config);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OfflineLM> OfflineLM::Create(AAssetManager *mgr,
|
||||
const OfflineLMConfig &config) {
|
||||
return std::make_unique<OfflineRnnLM>(mgr, config);
|
||||
}
|
||||
#endif
|
||||
|
||||
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) {
|
||||
// compute the max token seq so that we know how much space to allocate
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
@@ -20,6 +25,11 @@ class OfflineLM {
|
||||
|
||||
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OfflineLM> Create(AAssetManager *mgr,
|
||||
const OfflineLMConfig &config);
|
||||
#endif
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
|
||||
@@ -19,9 +19,21 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
Init();
|
||||
auto buf = ReadFile(config_.nemo_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.nemo_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<int64_t> shape =
|
||||
@@ -57,10 +69,8 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
private:
|
||||
void Init() {
|
||||
auto buf = ReadFile(config_.nemo_ctc.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
@@ -104,6 +114,12 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
@@ -23,6 +28,12 @@ namespace sherpa_onnx {
|
||||
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineNemoEncDecCtcModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineNemoEncDecCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
|
||||
@@ -21,9 +21,21 @@ class OfflineParaformerModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
Init();
|
||||
auto buf = ReadFile(config_.paraformer.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.paraformer.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
@@ -49,10 +61,8 @@ class OfflineParaformerModel::Impl {
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init() {
|
||||
auto buf = ReadFile(config_.paraformer.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
@@ -101,6 +111,12 @@ class OfflineParaformerModel::Impl {
|
||||
OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
@@ -16,6 +21,11 @@ namespace sherpa_onnx {
|
||||
class OfflineParaformerModel {
|
||||
public:
|
||||
explicit OfflineParaformerModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineParaformerModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineParaformerModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
|
||||
@@ -10,6 +10,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
@@ -46,10 +51,24 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(OfflineCtcModel::Create(config_.model_config)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerCtcImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
|
||||
Init();
|
||||
}
|
||||
#endif
|
||||
|
||||
void Init() {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>") &&
|
||||
!symbol_table_.contains("<eps>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
@@ -69,7 +88,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,4 +132,121 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
|
||||
if (!config.model_config.model_type.empty()) {
|
||||
const auto &model_type = config.model_config.model_type;
|
||||
if (model_type == "transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "tdnn") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
|
||||
|
||||
Ort::SessionOptions sess_opts;
|
||||
std::string model_filename;
|
||||
if (!config.model_config.transducer.encoder_filename.empty()) {
|
||||
model_filename = config.model_config.transducer.encoder_filename;
|
||||
} else if (!config.model_config.paraformer.model.empty()) {
|
||||
model_filename = config.model_config.paraformer.model;
|
||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else if (!config.model_config.tdnn.model.empty()) {
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please provide a model");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto buf = ReadFile(mgr, model_filename);
|
||||
|
||||
auto encoder_sess =
|
||||
std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata();
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
auto model_type_ptr =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type_ptr) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n\n"
|
||||
"Please refer to the following URLs to add metadata"
|
||||
"\n"
|
||||
"(0) Transducer models from icefall"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"pruned_transducer_stateless7/export-onnx.py#L303"
|
||||
"\n"
|
||||
"(1) Nemo CTC models\n "
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||
"\n"
|
||||
"(2) Paraformer"
|
||||
"\n "
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||
"\n "
|
||||
"(3) Whisper"
|
||||
"\n "
|
||||
"(4) Tdnn models of the yesno recipe from icefall"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
std::string model_type(model_type_ptr.get());
|
||||
|
||||
if (model_type == "conformer" || model_type == "zipformer" ||
|
||||
model_type == "zipformer2") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "tdnn") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\nUnsupported model_type: %s\n"
|
||||
"We support only the following model types at present: \n"
|
||||
" - Non-streaming transducer models from icefall\n"
|
||||
" - Non-streaming Paraformer models from FunASR\n"
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
@@ -19,6 +24,11 @@ class OfflineRecognizerImpl {
|
||||
static std::unique_ptr<OfflineRecognizerImpl> Create(
|
||||
const OfflineRecognizerConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OfflineRecognizerImpl> Create(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config);
|
||||
#endif
|
||||
|
||||
virtual ~OfflineRecognizerImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream(
|
||||
|
||||
@@ -11,6 +11,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||
@@ -100,6 +105,28 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerParaformerImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineParaformerModel>(mgr,
|
||||
config.model_config)) {
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
int32_t eos_id = symbol_table_["</s>"];
|
||||
decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Paraformer models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
@@ -73,6 +78,32 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit OfflineRecognizerTransducerImpl(
|
||||
AAssetManager *mgr, const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
||||
config_.model_config)) {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
} else if (config_.decoding_method == "modified_beam_search") {
|
||||
if (!config_.lm_config.model.empty()) {
|
||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
|
||||
@@ -12,6 +12,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -253,16 +258,32 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(
|
||||
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void Init() {
|
||||
// tokens.txt from whisper is base64 encoded, so we need to decode it
|
||||
symbol_table_.ApplyBase64Decode();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
|
||||
config_.model_config.whisper, model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for whisper. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,12 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizer::OfflineRecognizer(AAssetManager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: impl_(OfflineRecognizerImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
||||
: impl_(OfflineRecognizerImpl::Create(config)) {}
|
||||
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
@@ -55,6 +60,10 @@ class OfflineRecognizer {
|
||||
public:
|
||||
~OfflineRecognizer();
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRecognizer(AAssetManager *mgr, const OfflineRecognizerConfig &config);
|
||||
#endif
|
||||
|
||||
explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
|
||||
|
||||
/// Create a stream for decoding.
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -23,9 +23,21 @@ class OfflineRnnLM::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_{GetSessionOptions(config)},
|
||||
allocator_{} {
|
||||
Init(config);
|
||||
auto buf = ReadFile(config_.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineLMConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_{GetSessionOptions(config)},
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
|
||||
|
||||
@@ -37,10 +49,8 @@ class OfflineRnnLM::Impl {
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(const OfflineLMConfig &config) {
|
||||
auto buf = ReadFile(config_.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
@@ -66,6 +76,11 @@ class OfflineRnnLM::Impl {
|
||||
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRnnLM::OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineRnnLM::~OfflineRnnLM() = default;
|
||||
|
||||
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||
@@ -19,6 +24,10 @@ class OfflineRnnLM : public OfflineLM {
|
||||
|
||||
explicit OfflineRnnLM(const OfflineLMConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config);
|
||||
#endif
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
|
||||
@@ -19,9 +19,21 @@ class OfflineTdnnCtcModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
Init();
|
||||
auto buf = ReadFile(config_.tdnn.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.tdnn.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
|
||||
auto nnet_out =
|
||||
sess_->Run({}, input_names_ptr_.data(), &features, 1,
|
||||
@@ -48,10 +60,8 @@ class OfflineTdnnCtcModel::Impl {
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init() {
|
||||
auto buf = ReadFile(config_.tdnn.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
@@ -90,6 +100,12 @@ class OfflineTdnnCtcModel::Impl {
|
||||
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
@@ -22,6 +27,11 @@ namespace sherpa_onnx {
|
||||
class OfflineTdnnCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTdnnCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineTdnnCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
|
||||
@@ -38,6 +38,29 @@ class OfflineTransducerModel::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.transducer.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
|
||||
@@ -221,6 +244,12 @@ class OfflineTransducerModel::Impl {
|
||||
OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTransducerModel::OfflineTransducerModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineTransducerModel::~OfflineTransducerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
@@ -19,6 +24,11 @@ struct OfflineTransducerDecoderResult;
|
||||
class OfflineTransducerModel {
|
||||
public:
|
||||
explicit OfflineTransducerModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTransducerModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineTransducerModel();
|
||||
|
||||
/** Run the encoder.
|
||||
|
||||
@@ -35,6 +35,24 @@ class OfflineWhisperModel::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.whisper.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), &features, 1,
|
||||
@@ -226,6 +244,12 @@ class OfflineWhisperModel::Impl {
|
||||
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||
|
||||
@@ -11,6 +11,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
@@ -19,6 +24,11 @@ namespace sherpa_onnx {
|
||||
class OfflineWhisperModel {
|
||||
public:
|
||||
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineWhisperModel();
|
||||
|
||||
/** Run the encoder model.
|
||||
|
||||
Reference in New Issue
Block a user