diff --git a/build-android-arm64-v8a.sh b/build-android-arm64-v8a.sh new file mode 100755 index 00000000..6e9daa02 --- /dev/null +++ b/build-android-arm64-v8a.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -ex + +dir=build-android-arm64-v8a + +mkdir -p $dir +cd $dir + +# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android +# (optional) remove the hardcoded debug flag in Android NDK android-ndk +# issue: https://github.com/android/ndk/issues/243 +# +# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23 +# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23 +# +# delete "-g" line +# +# list(APPEND ANDROID_COMPILER_FLAGS +# -g +# -DANDROID + + +if [ -z $ANDROID_NDK ]; then + ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669 + # or use + # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk + # + # Inside the $ANDROID_NDK directory, you can find a binary ndk-build + # and some other files like the file "build/cmake/android.toolchain.cmake" + + if [ ! -d $ANDROID_NDK ]; then + # For macOS, I have installed Android Studio, select the menu + # Tools -> SDK manager -> Android SDK + # and set "Android SDK location" to /Users/fangjun/software/my-android + ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670 + fi +fi + +if [ ! -d $ANDROID_NDK ]; then + echo Please set the environment variable ANDROID_NDK before you run this script + exit 1 +fi + +echo "ANDROID_NDK: $ANDROID_NDK" +sleep 1 + +cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_PLATFORM=android-21 .. +# make VERBOSE=1 -j4 +make -j4 +make install/strip + diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index cb4a7aa3..2b226316 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -1,39 +1,39 @@ function(download_onnxruntime) include(FetchContent) - if(CMAKE_SYSTEM_NAME STREQUAL Linux) - if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) - # For embedded systems - set(possible_file_locations - $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz - ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz - /tmp/onnxruntime-linux-aarch64-1.14.0.tgz - /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz - ) - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") - set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - else() - # If you don't have access to the Internet, - # please pre-download onnxruntime - set(possible_file_locations - $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz - ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz - /tmp/onnxruntime-linux-x64-1.14.0.tgz - /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz - ) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) + # For embedded systems + set(possible_file_locations + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz + ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz + /tmp/onnxruntime-linux-aarch64-1.14.0.tgz + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz + ) + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") + set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") + elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64) + # If you don't have access to the Internet, + # please pre-download onnxruntime + set(possible_file_locations + $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz + ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz + /tmp/onnxruntime-linux-x64-1.14.0.tgz + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz + ) - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") - set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") - # After downloading, it contains: - # ./lib/libonnxruntime.so.1.14.0 - # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 - # - # ./include - # It contains all the needed header files - endif() + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") + set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") + # After downloading, it contains: + # ./lib/libonnxruntime.so.1.14.0 + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 + # + # ./include + # It contains all the needed header files elseif(APPLE) # If you don't have access to the Internet, # please pre-download onnxruntime @@ -69,6 +69,8 @@ function(download_onnxruntime) # ./include # It contains all the needed header files else() + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") endif() @@ -91,11 +93,15 @@ function(download_onnxruntime) endif() message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") - find_library(location_onnxruntime onnxruntime - PATHS - "${onnxruntime_SOURCE_DIR}/lib" - NO_CMAKE_SYSTEM_PATH - ) + if(ANDROID) + set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so) + else() + find_library(location_onnxruntime onnxruntime + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + NO_CMAKE_SYSTEM_PATH + ) + endif() message(STATUS "location_onnxruntime: ${location_onnxruntime}") diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 6663262c..8058346c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -26,6 +26,10 @@ endif() add_library(sherpa-onnx-core ${sources}) +if(ANDROID_NDK) + target_link_libraries(sherpa-onnx-core android log) +endif() + target_link_libraries(sherpa-onnx-core onnxruntime kaldi-native-fbank-core diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 0e29a3c8..7d32efed 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -12,6 +12,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( sess_opts_.SetIntraOpNumThreads(config.num_threads); sess_opts_.SetInterOpNumThreads(config.num_threads); - InitEncoder(config.encoder_filename); - InitDecoder(config.decoder_filename); - InitJoiner(config.joiner_filename); + { + auto buf = ReadFile(config.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } } -void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { - encoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +#if __ANDROID_API__ >= 9 +OnlineLstmTransducerModel::OnlineLstmTransducerModel( + AAssetManager *mgr, const OnlineTransducerModelConfig &config) + : env_(ORT_LOGGING_LEVEL_WARNING), + config_(config), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config.num_threads); + sess_opts_.SetInterOpNumThreads(config.num_threads); + + { + auto buf = ReadFile(mgr, config.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } +} +#endif + +void OnlineLstmTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(encoder_sess_.get(), &encoder_input_names_, &encoder_input_names_ptr_); @@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); } -void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { - decoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +void OnlineLstmTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(decoder_sess_.get(), &decoder_input_names_, &decoder_input_names_ptr_); @@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); } -void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) { - joiner_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +void OnlineLstmTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(joiner_sess_.get(), &joiner_input_names_, &joiner_input_names_ptr_); diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h index c24bfca4..a73912a7 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.h +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h @@ -9,6 +9,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { public: explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); +#if __ANDROID_API__ >= 9 + OnlineLstmTransducerModel(AAssetManager *mgr, + const OnlineTransducerModelConfig &config); +#endif + std::vector StackStates( const std::vector> &states) const override; @@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { OrtAllocator *Allocator() override { return allocator_; } private: - void InitEncoder(const std::string &encoder_filename); - void InitDecoder(const std::string &decoder_filename); - void InitJoiner(const std::string &joiner_filename); + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); private: Ort::Env env_; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 6292eb22..fe9f7baf 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -55,6 +55,17 @@ class OnlineRecognizer::Impl { std::make_unique(model_.get()); } +#if __ANDROID_API__ >= 9 + explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(mgr, config.model_config)), + sym_(mgr, config.tokens), + endpoint_(config_.endpoint_config) { + decoder_ = + std::make_unique(model_.get()); + } +#endif + std::unique_ptr CreateStream() const { auto stream = std::make_unique(config_.feat_config); stream->SetResult(decoder_->GetEmptyResult()); @@ -156,6 +167,13 @@ class OnlineRecognizer::Impl { OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr, + const OnlineRecognizerConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + OnlineRecognizer::~OnlineRecognizer() = default; std::unique_ptr OnlineRecognizer::CreateStream() const { diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 5066ee25..e057795d 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-stream.h" @@ -45,6 +50,11 @@ struct OnlineRecognizerConfig { class OnlineRecognizer { public: explicit OnlineRecognizer(const OnlineRecognizerConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config); +#endif + ~OnlineRecognizer(); /// Create a stream for decoding. diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 22f58199..effbf558 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -3,6 +3,11 @@ // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-transducer-model.h" +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include #include #include @@ -18,15 +23,16 @@ enum class ModelType { kUnkown, }; -static ModelType GetModelType(const OnlineTransducerModelConfig &config) { +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING); Ort::SessionOptions sess_opts; - auto sess = std::make_unique( - env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); + auto sess = std::make_unique(env, model_data, model_data_length, + sess_opts); Ort::ModelMetadata meta_data = sess->GetModelMetadata(); - if (config.debug) { + if (debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); fprintf(stderr, "%s\n", os.str().c_str()); @@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) { std::unique_ptr OnlineTransducerModel::Create( const OnlineTransducerModelConfig &config) { - auto model_type = GetModelType(config); + auto buffer = ReadFile(config.encoder_filename); + + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); switch (model_type) { case ModelType::kLstm: @@ -67,4 +75,24 @@ std::unique_ptr OnlineTransducerModel::Create( return nullptr; } +#if __ANDROID_API__ >= 9 +std::unique_ptr OnlineTransducerModel::Create( + AAssetManager *mgr, const OnlineTransducerModelConfig &config) { + auto buffer = ReadFile(mgr, config.encoder_filename); + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + + switch (model_type) { + case ModelType::kLstm: + return std::make_unique(mgr, config); + case ModelType::kZipformer: + return std::make_unique(mgr, config); + case ModelType::kUnkown: + return nullptr; + } + + // unreachable code + return nullptr; +} +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index baed186a..2757024c 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-transducer-model-config.h" @@ -22,6 +27,11 @@ class OnlineTransducerModel { static std::unique_ptr Create( const OnlineTransducerModelConfig &config); +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OnlineTransducerModelConfig &config); +#endif + /** Stack a list of individual states into a batch. * * It is the inverse operation of `UnStackStates`. diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 2a675fc0..7038274e 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -13,6 +13,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" @@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( sess_opts_.SetIntraOpNumThreads(config.num_threads); sess_opts_.SetInterOpNumThreads(config.num_threads); - InitEncoder(config.encoder_filename); - InitDecoder(config.decoder_filename); - InitJoiner(config.joiner_filename); + { + auto buf = ReadFile(config.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } } -void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { - encoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +#if __ANDROID_API__ >= 9 +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + AAssetManager *mgr, const OnlineTransducerModelConfig &config) + : env_(ORT_LOGGING_LEVEL_WARNING), + config_(config), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config.num_threads); + sess_opts_.SetInterOpNumThreads(config.num_threads); + + { + auto buf = ReadFile(mgr, config.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } +} +#endif + +void OnlineZipformerTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(encoder_sess_.get(), &encoder_input_names_, &encoder_input_names_ptr_); @@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { } } -void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { - decoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +void OnlineZipformerTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(decoder_sess_.get(), &decoder_input_names_, &decoder_input_names_ptr_); @@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); } -void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) { - joiner_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); +void OnlineZipformerTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::make_unique(env_, model_data, + model_data_length, sess_opts_); GetInputNames(joiner_sess_.get(), &joiner_input_names_, &joiner_input_names_ptr_); diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h index 779ac288..02a9742d 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -9,6 +9,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { explicit OnlineZipformerTransducerModel( const OnlineTransducerModelConfig &config); +#if __ANDROID_API__ >= 9 + OnlineZipformerTransducerModel(AAssetManager *mgr, + const OnlineTransducerModelConfig &config); +#endif + std::vector StackStates( const std::vector> &states) const override; @@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { OrtAllocator *Allocator() override { return allocator_; } private: - void InitEncoder(const std::string &encoder_filename); - void InitDecoder(const std::string &decoder_filename); - void InitJoiner(const std::string &joiner_filename); + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); private: Ort::Env env_; diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index fd23552d..5d3324bf 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -3,9 +3,16 @@ // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/onnx-utils.h" +#include #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#include "android/log.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { @@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) { fprintf(stderr, "\n"); } +std::vector ReadFile(const std::string &filename) { + std::ifstream input(filename, std::ios::binary); + std::vector buffer(std::istreambuf_iterator(input), {}); + return buffer; +} + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { + AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); + if (!asset) { + __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx", + "Read binary file: Load %s failed", filename.c_str()); + exit(-1); + } + + auto p = reinterpret_cast(AAsset_getBuffer(asset)); + size_t asset_length = AAsset_getLength(asset); + + AAsset_close(asset); + + std::vector buffer(p, p + asset_length); + + return buffer; +} +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 8efcefd1..a00414d0 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -14,6 +14,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { @@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) { std::fill(p, p + n, value); } +std::vector ReadFile(const std::string &filename); + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename); +#endif + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 3d4f2b78..6f18bdad 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -7,11 +7,32 @@ #include #include #include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif namespace sherpa_onnx { SymbolTable::SymbolTable(const std::string &filename) { std::ifstream is(filename); + Init(is); +} + +#if __ANDROID_API__ >= 9 +SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { + auto buf = ReadFile(mgr, filename); + + std::istrstream is(buf.data(), buf.size()); + Init(is); +} +#endif + +void SymbolTable::Init(std::istream &is) { std::string sym; int32_t id; while (is >> sym >> id) { diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 0e1b74a9..103e0f27 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + namespace sherpa_onnx { /// It manages mapping between symbols and integer IDs. @@ -22,6 +27,10 @@ class SymbolTable { /// Fields are separated by space(s). explicit SymbolTable(const std::string &filename); +#if __ANDROID_API__ >= 9 + SymbolTable(AAssetManager *mgr, const std::string &filename); +#endif + /// Return a string representation of this symbol table std::string ToString() const; @@ -36,6 +45,9 @@ class SymbolTable { /// Return true if there is a given symbol in the symbol table. bool contains(const std::string &sym) const; + private: + void Init(std::istream &is); + private: std::unordered_map sym2id_; std::unordered_map id2sym_; diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index bc3cee7b..98052a5c 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -20,7 +20,7 @@ #endif #if __ANDROID_API__ >= 8 -#include +#include "android/log.h" #define SHERPA_ONNX_LOGE(...) \ do { \ fprintf(stderr, ##__VA_ARGS__); \