Add build script for Android armv8a (#58)
This commit is contained in:
@@ -26,6 +26,10 @@ endif()
|
||||
|
||||
add_library(sherpa-onnx-core ${sources})
|
||||
|
||||
if(ANDROID_NDK)
|
||||
target_link_libraries(sherpa-onnx-core android log)
|
||||
endif()
|
||||
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
|
||||
@@ -12,6 +12,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
InitEncoder(config.encoder_filename);
|
||||
InitDecoder(config.decoder_filename);
|
||||
InitJoiner(config.joiner_filename);
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void OnlineLstmTransducerModel::InitEncoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineLstmTransducerModel::InitDecoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineLstmTransducerModel::InitJoiner(void *model_data,
|
||||
size_t model_data_length) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
public:
|
||||
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineLstmTransducerModel(AAssetManager *mgr,
|
||||
const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||
|
||||
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
OrtAllocator *Allocator() override { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
void InitDecoder(const std::string &decoder_filename);
|
||||
void InitJoiner(const std::string &joiner_filename);
|
||||
void InitEncoder(void *model_data, size_t model_data_length);
|
||||
void InitDecoder(void *model_data, size_t model_data_length);
|
||||
void InitJoiner(void *model_data, size_t model_data_length);
|
||||
|
||||
private:
|
||||
Ort::Env env_;
|
||||
|
||||
@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetResult(decoder_->GetEmptyResult());
|
||||
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
|
||||
const OnlineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OnlineRecognizer::~OnlineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
|
||||
class OnlineRecognizer {
|
||||
public:
|
||||
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
||||
#endif
|
||||
|
||||
~OnlineRecognizer();
|
||||
|
||||
/// Create a stream for decoding.
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -18,15 +23,16 @@ enum class ModelType {
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(
|
||||
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (config.debug) {
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
const OnlineTransducerModelConfig &config) {
|
||||
auto model_type = GetModelType(config);
|
||||
auto buffer = ReadFile(config.encoder_filename);
|
||||
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kLstm:
|
||||
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
|
||||
auto buffer = ReadFile(mgr, config.encoder_filename);
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kLstm:
|
||||
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
|
||||
case ModelType::kZipformer:
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||
case ModelType::kUnkown:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
@@ -22,6 +27,11 @@ class OnlineTransducerModel {
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
/** Stack a list of individual states into a batch.
|
||||
*
|
||||
* It is the inverse operation of `UnStackStates`.
|
||||
|
||||
@@ -13,6 +13,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
InitEncoder(config.encoder_filename);
|
||||
InitDecoder(config.decoder_filename);
|
||||
InitJoiner(config.joiner_filename);
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
|
||||
size_t model_data_length) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
||||
explicit OnlineZipformerTransducerModel(
|
||||
const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineZipformerTransducerModel(AAssetManager *mgr,
|
||||
const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||
|
||||
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
||||
OrtAllocator *Allocator() override { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
void InitDecoder(const std::string &decoder_filename);
|
||||
void InitJoiner(const std::string &joiner_filename);
|
||||
void InitEncoder(void *model_data, size_t model_data_length);
|
||||
void InitDecoder(void *model_data, size_t model_data_length);
|
||||
void InitJoiner(void *model_data, size_t model_data_length);
|
||||
|
||||
private:
|
||||
Ort::Env env_;
|
||||
|
||||
@@ -3,9 +3,16 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#include "android/log.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename) {
|
||||
std::ifstream input(filename, std::ios::binary);
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
|
||||
return buffer;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
|
||||
if (!asset) {
|
||||
__android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
|
||||
"Read binary file: Load %s failed", filename.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
|
||||
size_t asset_length = AAsset_getLength(asset);
|
||||
|
||||
AAsset_close(asset);
|
||||
|
||||
std::vector<char> buffer(p, p + asset_length);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -14,6 +14,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
|
||||
std::fill(p, p + n, value);
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -7,11 +7,32 @@
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <strstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
SymbolTable::SymbolTable(const std::string &filename) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
|
||||
auto buf = ReadFile(mgr, filename);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
Init(is);
|
||||
}
|
||||
#endif
|
||||
|
||||
void SymbolTable::Init(std::istream &is) {
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// It manages mapping between symbols and integer IDs.
|
||||
@@ -22,6 +27,10 @@ class SymbolTable {
|
||||
/// Fields are separated by space(s).
|
||||
explicit SymbolTable(const std::string &filename);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
/// Return a string representation of this symbol table
|
||||
std::string ToString() const;
|
||||
|
||||
@@ -36,6 +45,9 @@ class SymbolTable {
|
||||
/// Return true if there is a given symbol in the symbol table.
|
||||
bool contains(const std::string &sym) const;
|
||||
|
||||
private:
|
||||
void Init(std::istream &is);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, int32_t> sym2id_;
|
||||
std::unordered_map<int32_t, std::string> id2sym_;
|
||||
|
||||
Reference in New Issue
Block a user