Add two-pass speech recognition Android/iOS demo (#304)

This commit is contained in:
Fangjun Kuang
2023-09-12 15:40:16 +08:00
committed by GitHub
parent 8982984ea2
commit debab7c091
97 changed files with 3546 additions and 57 deletions

View File

@@ -347,6 +347,8 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
const auto &text = result.text;
auto r = new SherpaOnnxOfflineRecognizerResult;
memset(r, 0, sizeof(SherpaOnnxOfflineRecognizerResult));
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;

View File

@@ -100,4 +100,42 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
return nullptr;
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}
{
auto buffer = ReadFile(mgr, filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(mgr, config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}
return nullptr;
}
#endif
} // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <string>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -16,9 +21,15 @@ namespace sherpa_onnx {
class OfflineCtcModel {
public:
virtual ~OfflineCtcModel() = default;
static std::unique_ptr<OfflineCtcModel> Create(
const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineCtcModel> Create(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.

View File

@@ -16,6 +16,13 @@ std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(config);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineLM> OfflineLM::Create(AAssetManager *mgr,
const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(mgr, config);
}
#endif
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
// compute the max token seq so that we know how much space to allocate

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
@@ -20,6 +25,11 @@ class OfflineLM {
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineLM> Create(AAssetManager *mgr,
const OfflineLMConfig &config);
#endif
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.

View File

@@ -19,9 +19,21 @@ class OfflineNemoEncDecCtcModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.nemo_ctc.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<int64_t> shape =
@@ -57,10 +69,8 @@ class OfflineNemoEncDecCtcModel::Impl {
std::string FeatureNormalizationMethod() const { return normalize_type_; }
private:
void Init() {
auto buf = ReadFile(config_.nemo_ctc.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -104,6 +114,12 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -23,6 +28,12 @@ namespace sherpa_onnx {
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
public:
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineNemoEncDecCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineNemoEncDecCtcModel() override;
/** Run the forward method of the model.

View File

@@ -21,9 +21,21 @@ class OfflineParaformerModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.paraformer.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.paraformer.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
@@ -49,10 +61,8 @@ class OfflineParaformerModel::Impl {
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init() {
auto buf = ReadFile(config_.paraformer.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -101,6 +111,12 @@ class OfflineParaformerModel::Impl {
OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineParaformerModel::~OfflineParaformerModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -16,6 +21,11 @@ namespace sherpa_onnx {
class OfflineParaformerModel {
public:
explicit OfflineParaformerModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineParaformerModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineParaformerModel();
/** Run the forward method of the model.

View File

@@ -10,6 +10,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-model.h"
@@ -46,10 +51,24 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(OfflineCtcModel::Create(config_.model_config)) {
Init();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerCtcImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
Init();
}
#endif
void Init() {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
SHERPA_ONNX_LOGE(
@@ -69,7 +88,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}

View File

@@ -132,4 +132,121 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
exit(-1);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config) {
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
Ort::SessionOptions sess_opts;
std::string model_filename;
if (!config.model_config.transducer.encoder_filename.empty()) {
model_filename = config.model_config.transducer.encoder_filename;
} else if (!config.model_config.paraformer.model.empty()) {
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
SHERPA_ONNX_LOGE("Please provide a model");
exit(-1);
}
auto buf = ReadFile(mgr, model_filename);
auto encoder_sess =
std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
Ort::ModelMetadata meta_data = encoder_sess->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
"\n"
"(0) Transducer models from icefall"
"\n "
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"pruned_transducer_stateless7/export-onnx.py#L303"
"\n"
"(1) Nemo CTC models\n "
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"\n"
"(2) Paraformer"
"\n "
"https://huggingface.co/csukuangfj/"
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n "
"(3) Whisper"
"\n "
"(4) Tdnn models of the yesno recipe from icefall"
"\n "
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
"\n"
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer" ||
model_type == "zipformer2") {
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
}
if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (model_type == "tdnn") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
}
SHERPA_ONNX_LOGE(
"\nUnsupported model_type: %s\n"
"We support only the following model types at present: \n"
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n",
model_type.c_str());
exit(-1);
}
#endif
} // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -19,6 +24,11 @@ class OfflineRecognizerImpl {
static std::unique_ptr<OfflineRecognizerImpl> Create(
const OfflineRecognizerConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineRecognizerImpl> Create(
AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(

View File

@@ -11,6 +11,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
@@ -100,6 +105,28 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
config_.feat_config.normalize_samples = false;
}
#if __ANDROID_API__ >= 9
OfflineRecognizerParaformerImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineParaformerModel>(mgr,
config.model_config)) {
if (config.decoding_method == "greedy_search") {
int32_t eos_id = symbol_table_["</s>"];
decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#endif
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}

View File

@@ -10,6 +10,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -73,6 +78,32 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
#if __ANDROID_API__ >= 9
explicit OfflineRecognizerTransducerImpl(
AAssetManager *mgr, const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(mgr,
config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
exit(-1);
}
}
#endif
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default

View File

@@ -12,6 +12,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -253,16 +258,32 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
Init();
}
#if __ANDROID_API__ >= 9
OfflineRecognizerWhisperImpl(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(
std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
Init();
}
#endif
void Init() {
// tokens.txt from whisper is base64 encoded, so we need to decode it
symbol_table_.ApplyBase64Decode();
if (config.decoding_method == "greedy_search") {
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
config_.model_config.whisper, model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
config.decoding_method.c_str());
config_.decoding_method.c_str());
exit(-1);
}
}

View File

@@ -58,6 +58,12 @@ std::string OfflineRecognizerConfig::ToString() const {
return os.str();
}
#if __ANDROID_API__ >= 9
OfflineRecognizer::OfflineRecognizer(AAssetManager *mgr,
const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(mgr, config)) {}
#endif
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(config)) {}

View File

@@ -9,6 +9,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -55,6 +60,10 @@ class OfflineRecognizer {
public:
~OfflineRecognizer();
#if __ANDROID_API__ >= 9
OfflineRecognizer(AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
/// Create a stream for decoding.

View File

@@ -11,8 +11,8 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@@ -23,9 +23,21 @@ class OfflineRnnLM::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{GetSessionOptions(config)},
allocator_{} {
Init(config);
auto buf = ReadFile(config_.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineLMConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{GetSessionOptions(config)},
allocator_{} {
auto buf = ReadFile(mgr, config_.model);
Init(buf.data(), buf.size());
}
#endif
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
@@ -37,10 +49,8 @@ class OfflineRnnLM::Impl {
}
private:
void Init(const OfflineLMConfig &config) {
auto buf = ReadFile(config_.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -66,6 +76,11 @@ class OfflineRnnLM::Impl {
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineRnnLM::OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineRnnLM::~OfflineRnnLM() = default;
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {

View File

@@ -7,6 +7,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-lm.h"
@@ -19,6 +24,10 @@ class OfflineRnnLM : public OfflineLM {
explicit OfflineRnnLM(const OfflineLMConfig &config);
#if __ANDROID_API__ >= 9
OfflineRnnLM(AAssetManager *mgr, const OfflineLMConfig &config);
#endif
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.

View File

@@ -19,9 +19,21 @@ class OfflineTdnnCtcModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
Init();
auto buf = ReadFile(config_.tdnn.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.tdnn.model);
Init(buf.data(), buf.size());
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
auto nnet_out =
sess_->Run({}, input_names_ptr_.data(), &features, 1,
@@ -48,10 +60,8 @@ class OfflineTdnnCtcModel::Impl {
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init() {
auto buf = ReadFile(config_.tdnn.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
@@ -90,6 +100,12 @@ class OfflineTdnnCtcModel::Impl {
OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -22,6 +27,11 @@ namespace sherpa_onnx {
class OfflineTdnnCtcModel : public OfflineCtcModel {
public:
explicit OfflineTdnnCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineTdnnCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineTdnnCtcModel() override;
/** Run the forward method of the model.

View File

@@ -38,6 +38,29 @@ class OfflineTransducerModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.transducer.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.transducer.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#endif
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
@@ -221,6 +244,12 @@ class OfflineTransducerModel::Impl {
OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineTransducerModel::OfflineTransducerModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTransducerModel::~OfflineTransducerModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(

View File

@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -19,6 +24,11 @@ struct OfflineTransducerDecoderResult;
class OfflineTransducerModel {
public:
explicit OfflineTransducerModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineTransducerModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineTransducerModel();
/** Run the encoder.

View File

@@ -35,6 +35,24 @@ class OfflineWhisperModel::Impl {
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#endif
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), &features, 1,
@@ -226,6 +244,12 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineWhisperModel::~OfflineWhisperModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(

View File

@@ -11,6 +11,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
@@ -19,6 +24,11 @@ namespace sherpa_onnx {
class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineWhisperModel();
/** Run the encoder model.

View File

@@ -20,6 +20,7 @@
#include <fstream>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/wave-reader.h"
@@ -53,7 +54,7 @@ class SherpaOnnx {
stream_->InputFinished();
}
const std::string GetText() const {
std::string GetText() const {
auto result = recognizer_.GetResult(stream_.get());
return result.text;
}
@@ -67,7 +68,13 @@ class SherpaOnnx {
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
void Reset() const { return recognizer_.Reset(stream_.get()); }
void Reset(bool recreate) {
if (recreate) {
stream_ = recognizer_.CreateStream();
} else {
recognizer_.Reset(stream_.get());
}
}
void Decode() const { recognizer_.DecodeStream(stream_.get()); }
@@ -77,6 +84,28 @@ class SherpaOnnx {
int32_t input_sample_rate_ = -1;
};
class SherpaOnnxOffline {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxOffline(AAssetManager *mgr, const OfflineRecognizerConfig &config)
: recognizer_(mgr, config) {}
#endif
explicit SherpaOnnxOffline(const OfflineRecognizerConfig &config)
: recognizer_(config) {}
std::string Decode(int32_t sample_rate, const float *samples, int32_t n) {
auto stream = recognizer_.CreateStream();
stream->AcceptWaveform(sample_rate, samples, n);
recognizer_.DecodeStream(stream.get());
return stream->GetResult().text;
}
private:
OfflineRecognizer recognizer_;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
@@ -248,6 +277,122 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
return ans;
}
static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
OfflineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.model = p;
env->ReleaseStringUTFChars(s, p);
// whisper
fid = env->GetFieldID(model_config_cls, "whisper",
"Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");
jobject whisper_config = env->GetObjectField(model_config, fid);
jclass whisper_config_cls = env->GetObjectClass(whisper_config);
fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.whisper.decoder = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
@@ -287,10 +432,48 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxOffline(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_newFromFile(JNIEnv *env,
jobject /*obj*/,
jobject _config) {
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxOffline(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
model->Reset();
model->Reset(recreate);
}
SHERPA_ONNX_EXTERN_C
@@ -328,6 +511,22 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxOffline *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto text = model->Decode(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return env->NewStringUTF(text.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {