Fix reading hotwords file for android (#354)
This commit is contained in:
@@ -12,6 +12,13 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include <strstream>
|
||||||
|
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/online-lm.h"
|
#include "sherpa-onnx/csrc/online-lm.h"
|
||||||
@@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
sym_(config.model_config.tokens),
|
sym_(config.model_config.tokens),
|
||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
if (!config_.hotwords_file.empty()) {
|
|
||||||
InitHotwords();
|
|
||||||
}
|
|
||||||
if (sym_.contains("<unk>")) {
|
if (sym_.contains("<unk>")) {
|
||||||
unk_id_ = sym_["<unk>"];
|
unk_id_ = sym_["<unk>"];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
|
if (!config_.hotwords_file.empty()) {
|
||||||
|
InitHotwords();
|
||||||
|
}
|
||||||
|
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OnlineLM::Create(config.lm_config);
|
lm_ = OnlineLM::Create(config.lm_config);
|
||||||
}
|
}
|
||||||
@@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
|
#if 0
|
||||||
|
// TODO(fangjun): Implement it
|
||||||
|
if (!config_.lm_config.model.empty()) {
|
||||||
|
lm_ = OnlineLM::Create(mgr, config.lm_config);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (!config_.hotwords_file.empty()) {
|
||||||
|
InitHotwords(mgr);
|
||||||
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, unk_id_);
|
config_.lm_config.scale, unk_id_);
|
||||||
@@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
s->Reset();
|
s->Reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
void InitHotwords() {
|
void InitHotwords() {
|
||||||
// each line in hotwords_file contains space-separated words
|
// each line in hotwords_file contains space-separated words
|
||||||
|
|
||||||
@@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
#if __ANDROID_API__ >= 9
|
||||||
|
void InitHotwords(AAssetManager *mgr) {
|
||||||
|
// each line in hotwords_file contains space-separated words
|
||||||
|
|
||||||
|
auto buf = ReadFile(mgr, config_.hotwords_file);
|
||||||
|
|
||||||
|
std::istrstream is(buf.data(), buf.size());
|
||||||
|
|
||||||
|
if (!is) {
|
||||||
|
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||||
|
config_.hotwords_file.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||||
|
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
hotwords_graph_ =
|
||||||
|
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void InitOnlineStream(OnlineStream *stream) const {
|
void InitOnlineStream(OnlineStream *stream) const {
|
||||||
auto r = decoder_->GetEmptyResult();
|
auto r = decoder_->GetEmptyResult();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user