Add build script for Android armv8a (#58)

This commit is contained in:
Fangjun Kuang
2023-02-22 22:36:05 +08:00
committed by GitHub
parent ef93dcd733
commit 5a5d029490
16 changed files with 398 additions and 72 deletions

61
build-android-arm64-v8a.sh Executable file
View File

@@ -0,0 +1,61 @@
#!/usr/bin/env bash
set -ex
dir=build-android-arm64-v8a
mkdir -p $dir
cd $dir
# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
# (optional) remove the hardcoded debug flag in Android NDK android-ndk
# issue: https://github.com/android/ndk/issues/243
#
# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23
# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23
#
# delete "-g" line
#
# list(APPEND ANDROID_COMPILER_FLAGS
# -g
# -DANDROID
if [ -z $ANDROID_NDK ]; then
ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
# or use
# ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
#
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake"
if [ ! -d $ANDROID_NDK ]; then
# For macOS, I have installed Android Studio, select the menu
# Tools -> SDK manager -> Android SDK
# and set "Android SDK location" to /Users/fangjun/software/my-android
ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670
fi
fi
if [ ! -d $ANDROID_NDK ]; then
echo Please set the environment variable ANDROID_NDK before you run this script
exit 1
fi
echo "ANDROID_NDK: $ANDROID_NDK"
sleep 1
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DCMAKE_BUILD_TYPE=Release \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
-DCMAKE_INSTALL_PREFIX=./install \
-DANDROID_ABI="arm64-v8a" \
-DANDROID_PLATFORM=android-21 ..
# make VERBOSE=1 -j4
make -j4
make install/strip

View File

@@ -1,7 +1,9 @@
function(download_onnxruntime) function(download_onnxruntime)
include(FetchContent) include(FetchContent)
if(CMAKE_SYSTEM_NAME STREQUAL Linux) message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
# For embedded systems # For embedded systems
set(possible_file_locations set(possible_file_locations
@@ -13,8 +15,7 @@ function(download_onnxruntime)
) )
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
else()
# If you don't have access to the Internet, # If you don't have access to the Internet,
# please pre-download onnxruntime # please pre-download onnxruntime
set(possible_file_locations set(possible_file_locations
@@ -33,7 +34,6 @@ function(download_onnxruntime)
# #
# ./include # ./include
# It contains all the needed header files # It contains all the needed header files
endif()
elseif(APPLE) elseif(APPLE)
# If you don't have access to the Internet, # If you don't have access to the Internet,
# please pre-download onnxruntime # please pre-download onnxruntime
@@ -69,6 +69,8 @@ function(download_onnxruntime)
# ./include # ./include
# It contains all the needed header files # It contains all the needed header files
else() else()
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
endif() endif()
@@ -91,11 +93,15 @@ function(download_onnxruntime)
endif() endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
if(ANDROID)
set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so)
else()
find_library(location_onnxruntime onnxruntime find_library(location_onnxruntime onnxruntime
PATHS PATHS
"${onnxruntime_SOURCE_DIR}/lib" "${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH NO_CMAKE_SYSTEM_PATH
) )
endif()
message(STATUS "location_onnxruntime: ${location_onnxruntime}") message(STATUS "location_onnxruntime: ${location_onnxruntime}")

View File

@@ -26,6 +26,10 @@ endif()
add_library(sherpa-onnx-core ${sources}) add_library(sherpa-onnx-core ${sources})
if(ANDROID_NDK)
target_link_libraries(sherpa-onnx-core android log)
endif()
target_link_libraries(sherpa-onnx-core target_link_libraries(sherpa-onnx-core
onnxruntime onnxruntime
kaldi-native-fbank-core kaldi-native-fbank-core

View File

@@ -12,6 +12,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
sess_opts_.SetIntraOpNumThreads(config.num_threads); sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads); sess_opts_.SetInterOpNumThreads(config.num_threads);
InitEncoder(config.encoder_filename); {
InitDecoder(config.decoder_filename); auto buf = ReadFile(config.encoder_filename);
InitJoiner(config.joiner_filename); InitEncoder(buf.data(), buf.size());
} }
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { {
encoder_sess_ = std::make_unique<Ort::Session>( auto buf = ReadFile(config.decoder_filename);
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(mgr, config.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#endif
void OnlineLstmTransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_, GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_); &encoder_input_names_ptr_);
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
} }
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { void OnlineLstmTransducerModel::InitDecoder(void *model_data,
decoder_sess_ = std::make_unique<Ort::Session>( size_t model_data_length) {
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_, GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_); &decoder_input_names_ptr_);
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
} }
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) { void OnlineLstmTransducerModel::InitJoiner(void *model_data,
joiner_sess_ = std::make_unique<Ort::Session>( size_t model_data_length) {
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_, GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_); &joiner_input_names_ptr_);

View File

@@ -9,6 +9,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
public: public:
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineLstmTransducerModel(AAssetManager *mgr,
const OnlineTransducerModelConfig &config);
#endif
std::vector<Ort::Value> StackStates( std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override; const std::vector<std::vector<Ort::Value>> &states) const override;
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
OrtAllocator *Allocator() override { return allocator_; } OrtAllocator *Allocator() override { return allocator_; }
private: private:
void InitEncoder(const std::string &encoder_filename); void InitEncoder(void *model_data, size_t model_data_length);
void InitDecoder(const std::string &decoder_filename); void InitDecoder(void *model_data, size_t model_data_length);
void InitJoiner(const std::string &joiner_filename); void InitJoiner(void *model_data, size_t model_data_length);
private: private:
Ort::Env env_; Ort::Env env_;

View File

@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
} }
#if __ANDROID_API__ >= 9
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.tokens),
endpoint_(config_.endpoint_config) {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
}
#endif
std::unique_ptr<OnlineStream> CreateStream() const { std::unique_ptr<OnlineStream> CreateStream() const {
auto stream = std::make_unique<OnlineStream>(config_.feat_config); auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetResult(decoder_->GetEmptyResult()); stream->SetResult(decoder_->GetEmptyResult());
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(config)) {} : impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineRecognizer::~OnlineRecognizer() = default; OnlineRecognizer::~OnlineRecognizer() = default;
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {

View File

@@ -8,6 +8,11 @@
#include <memory> #include <memory>
#include <string> #include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-stream.h"
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
class OnlineRecognizer { class OnlineRecognizer {
public: public:
explicit OnlineRecognizer(const OnlineRecognizerConfig &config); explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
#if __ANDROID_API__ >= 9
OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config);
#endif
~OnlineRecognizer(); ~OnlineRecognizer();
/// Create a stream for decoding. /// Create a stream for decoding.

View File

@@ -3,6 +3,11 @@
// Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/online-transducer-model.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
@@ -18,15 +23,16 @@ enum class ModelType {
kUnkown, kUnkown,
}; };
static ModelType GetModelType(const OnlineTransducerModelConfig &config) { static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING); Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts; Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>( auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata(); Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (config.debug) { if (debug) {
std::ostringstream os; std::ostringstream os;
PrintModelMetadata(os, meta_data); PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str()); fprintf(stderr, "%s\n", os.str().c_str());
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const OnlineTransducerModelConfig &config) { const OnlineTransducerModelConfig &config) {
auto model_type = GetModelType(config); auto buffer = ReadFile(config.encoder_filename);
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
switch (model_type) { switch (model_type) {
case ModelType::kLstm: case ModelType::kLstm:
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return nullptr; return nullptr;
} }
#if __ANDROID_API__ >= 9
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
auto buffer = ReadFile(mgr, config.encoder_filename);
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
switch (model_type) {
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
case ModelType::kUnkown:
return nullptr;
}
// unreachable code
return nullptr;
}
#endif
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -22,6 +27,11 @@ class OnlineTransducerModel {
static std::unique_ptr<OnlineTransducerModel> Create( static std::unique_ptr<OnlineTransducerModel> Create(
const OnlineTransducerModelConfig &config); const OnlineTransducerModelConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OnlineTransducerModel> Create(
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
#endif
/** Stack a list of individual states into a batch. /** Stack a list of individual states into a batch.
* *
* It is the inverse operation of `UnStackStates`. * It is the inverse operation of `UnStackStates`.

View File

@@ -13,6 +13,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
sess_opts_.SetIntraOpNumThreads(config.num_threads); sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads); sess_opts_.SetInterOpNumThreads(config.num_threads);
InitEncoder(config.encoder_filename); {
InitDecoder(config.decoder_filename); auto buf = ReadFile(config.encoder_filename);
InitJoiner(config.joiner_filename); InitEncoder(buf.data(), buf.size());
} }
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { {
encoder_sess_ = std::make_unique<Ort::Session>( auto buf = ReadFile(config.decoder_filename);
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(mgr, config.encoder_filename);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.decoder_filename);
InitDecoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.joiner_filename);
InitJoiner(buf.data(), buf.size());
}
}
#endif
void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_, GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_); &encoder_input_names_ptr_);
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
} }
} }
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
decoder_sess_ = std::make_unique<Ort::Session>( size_t model_data_length) {
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_, GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_); &decoder_input_names_ptr_);
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
} }
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) { void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
joiner_sess_ = std::make_unique<Ort::Session>( size_t model_data_length) {
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
GetInputNames(joiner_sess_.get(), &joiner_input_names_, GetInputNames(joiner_sess_.get(), &joiner_input_names_,
&joiner_input_names_ptr_); &joiner_input_names_ptr_);

View File

@@ -9,6 +9,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
explicit OnlineZipformerTransducerModel( explicit OnlineZipformerTransducerModel(
const OnlineTransducerModelConfig &config); const OnlineTransducerModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineZipformerTransducerModel(AAssetManager *mgr,
const OnlineTransducerModelConfig &config);
#endif
std::vector<Ort::Value> StackStates( std::vector<Ort::Value> StackStates(
const std::vector<std::vector<Ort::Value>> &states) const override; const std::vector<std::vector<Ort::Value>> &states) const override;
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
OrtAllocator *Allocator() override { return allocator_; } OrtAllocator *Allocator() override { return allocator_; }
private: private:
void InitEncoder(const std::string &encoder_filename); void InitEncoder(void *model_data, size_t model_data_length);
void InitDecoder(const std::string &decoder_filename); void InitDecoder(void *model_data, size_t model_data_length);
void InitJoiner(const std::string &joiner_filename); void InitJoiner(void *model_data, size_t model_data_length);
private: private:
Ort::Env env_; Ort::Env env_;

View File

@@ -3,9 +3,16 @@
// Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include <fstream>
#include <string> #include <string>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#include "android/log.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
std::vector<char> ReadFile(const std::string &filename) {
std::ifstream input(filename, std::ios::binary);
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
return buffer;
}
#if __ANDROID_API__ >= 9
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
if (!asset) {
__android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
"Read binary file: Load %s failed", filename.c_str());
exit(-1);
}
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
size_t asset_length = AAsset_getLength(asset);
AAsset_close(asset);
std::vector<char> buffer(p, p + asset_length);
return buffer;
}
#endif
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -14,6 +14,11 @@
#include <string> #include <string>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
std::fill(p, p + n, value); std::fill(p, p + n, value);
} }
std::vector<char> ReadFile(const std::string &filename);
#if __ANDROID_API__ >= 9
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
#endif
} // namespace sherpa_onnx } // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_

View File

@@ -7,11 +7,32 @@
#include <cassert> #include <cassert>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <strstream>
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
namespace sherpa_onnx { namespace sherpa_onnx {
SymbolTable::SymbolTable(const std::string &filename) { SymbolTable::SymbolTable(const std::string &filename) {
std::ifstream is(filename); std::ifstream is(filename);
Init(is);
}
#if __ANDROID_API__ >= 9
SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
auto buf = ReadFile(mgr, filename);
std::istrstream is(buf.data(), buf.size());
Init(is);
}
#endif
void SymbolTable::Init(std::istream &is) {
std::string sym; std::string sym;
int32_t id; int32_t id;
while (is >> sym >> id) { while (is >> sym >> id) {

View File

@@ -8,6 +8,11 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
namespace sherpa_onnx { namespace sherpa_onnx {
/// It manages mapping between symbols and integer IDs. /// It manages mapping between symbols and integer IDs.
@@ -22,6 +27,10 @@ class SymbolTable {
/// Fields are separated by space(s). /// Fields are separated by space(s).
explicit SymbolTable(const std::string &filename); explicit SymbolTable(const std::string &filename);
#if __ANDROID_API__ >= 9
SymbolTable(AAssetManager *mgr, const std::string &filename);
#endif
/// Return a string representation of this symbol table /// Return a string representation of this symbol table
std::string ToString() const; std::string ToString() const;
@@ -36,6 +45,9 @@ class SymbolTable {
/// Return true if there is a given symbol in the symbol table. /// Return true if there is a given symbol in the symbol table.
bool contains(const std::string &sym) const; bool contains(const std::string &sym) const;
private:
void Init(std::istream &is);
private: private:
std::unordered_map<std::string, int32_t> sym2id_; std::unordered_map<std::string, int32_t> sym2id_;
std::unordered_map<int32_t, std::string> id2sym_; std::unordered_map<int32_t, std::string> id2sym_;

View File

@@ -20,7 +20,7 @@
#endif #endif
#if __ANDROID_API__ >= 8 #if __ANDROID_API__ >= 8
#include <android/log.h> #include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \ #define SHERPA_ONNX_LOGE(...) \
do { \ do { \
fprintf(stderr, ##__VA_ARGS__); \ fprintf(stderr, ##__VA_ARGS__); \