Add build script for Android armv8a (#58)
This commit is contained in:
61
build-android-arm64-v8a.sh
Executable file
61
build-android-arm64-v8a.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env bash
|
||||
set -ex
|
||||
|
||||
dir=build-android-arm64-v8a
|
||||
|
||||
mkdir -p $dir
|
||||
cd $dir
|
||||
|
||||
# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
|
||||
# (optional) remove the hardcoded debug flag in Android NDK android-ndk
|
||||
# issue: https://github.com/android/ndk/issues/243
|
||||
#
|
||||
# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23
|
||||
# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23
|
||||
#
|
||||
# delete "-g" line
|
||||
#
|
||||
# list(APPEND ANDROID_COMPILER_FLAGS
|
||||
# -g
|
||||
# -DANDROID
|
||||
|
||||
|
||||
if [ -z $ANDROID_NDK ]; then
|
||||
ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
|
||||
# or use
|
||||
# ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
|
||||
#
|
||||
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
|
||||
# and some other files like the file "build/cmake/android.toolchain.cmake"
|
||||
|
||||
if [ ! -d $ANDROID_NDK ]; then
|
||||
# For macOS, I have installed Android Studio, select the menu
|
||||
# Tools -> SDK manager -> Android SDK
|
||||
# and set "Android SDK location" to /Users/fangjun/software/my-android
|
||||
ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -d $ANDROID_NDK ]; then
|
||||
echo Please set the environment variable ANDROID_NDK before you run this script
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "ANDROID_NDK: $ANDROID_NDK"
|
||||
sleep 1
|
||||
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_JNI=ON \
|
||||
-DCMAKE_INSTALL_PREFIX=./install \
|
||||
-DANDROID_ABI="arm64-v8a" \
|
||||
-DANDROID_PLATFORM=android-21 ..
|
||||
# make VERBOSE=1 -j4
|
||||
make -j4
|
||||
make install/strip
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
function(download_onnxruntime)
|
||||
include(FetchContent)
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||
# For embedded systems
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
/tmp/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
)
|
||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
|
||||
set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
|
||||
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
|
||||
else()
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download onnxruntime
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz
|
||||
${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||
${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||
/tmp/onnxruntime-linux-x64-1.14.0.tgz
|
||||
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
|
||||
)
|
||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||
# For embedded systems
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
/tmp/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz
|
||||
)
|
||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
|
||||
set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
|
||||
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download onnxruntime
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz
|
||||
${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||
${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||
/tmp/onnxruntime-linux-x64-1.14.0.tgz
|
||||
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
|
||||
)
|
||||
|
||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
|
||||
set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090")
|
||||
# After downloading, it contains:
|
||||
# ./lib/libonnxruntime.so.1.14.0
|
||||
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0
|
||||
#
|
||||
# ./include
|
||||
# It contains all the needed header files
|
||||
endif()
|
||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
|
||||
set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090")
|
||||
# After downloading, it contains:
|
||||
# ./lib/libonnxruntime.so.1.14.0
|
||||
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0
|
||||
#
|
||||
# ./include
|
||||
# It contains all the needed header files
|
||||
elseif(APPLE)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download onnxruntime
|
||||
@@ -69,6 +69,8 @@ function(download_onnxruntime)
|
||||
# ./include
|
||||
# It contains all the needed header files
|
||||
else()
|
||||
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
|
||||
endif()
|
||||
|
||||
@@ -91,11 +93,15 @@ function(download_onnxruntime)
|
||||
endif()
|
||||
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||
|
||||
find_library(location_onnxruntime onnxruntime
|
||||
PATHS
|
||||
"${onnxruntime_SOURCE_DIR}/lib"
|
||||
NO_CMAKE_SYSTEM_PATH
|
||||
)
|
||||
if(ANDROID)
|
||||
set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so)
|
||||
else()
|
||||
find_library(location_onnxruntime onnxruntime
|
||||
PATHS
|
||||
"${onnxruntime_SOURCE_DIR}/lib"
|
||||
NO_CMAKE_SYSTEM_PATH
|
||||
)
|
||||
endif()
|
||||
|
||||
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||
|
||||
|
||||
@@ -26,6 +26,10 @@ endif()
|
||||
|
||||
add_library(sherpa-onnx-core ${sources})
|
||||
|
||||
if(ANDROID_NDK)
|
||||
target_link_libraries(sherpa-onnx-core android log)
|
||||
endif()
|
||||
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
|
||||
@@ -12,6 +12,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
InitEncoder(config.encoder_filename);
|
||||
InitDecoder(config.decoder_filename);
|
||||
InitJoiner(config.joiner_filename);
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void OnlineLstmTransducerModel::InitEncoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineLstmTransducerModel::InitDecoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
|
||||
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineLstmTransducerModel::InitJoiner(void *model_data,
|
||||
size_t model_data_length) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
public:
|
||||
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineLstmTransducerModel(AAssetManager *mgr,
|
||||
const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||
|
||||
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
OrtAllocator *Allocator() override { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
void InitDecoder(const std::string &decoder_filename);
|
||||
void InitJoiner(const std::string &joiner_filename);
|
||||
void InitEncoder(void *model_data, size_t model_data_length);
|
||||
void InitDecoder(void *model_data, size_t model_data_length);
|
||||
void InitJoiner(void *model_data, size_t model_data_length);
|
||||
|
||||
private:
|
||||
Ort::Env env_;
|
||||
|
||||
@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetResult(decoder_->GetEmptyResult());
|
||||
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
|
||||
const OnlineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OnlineRecognizer::~OnlineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
|
||||
class OnlineRecognizer {
|
||||
public:
|
||||
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
||||
#endif
|
||||
|
||||
~OnlineRecognizer();
|
||||
|
||||
/// Create a stream for decoding.
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -18,15 +23,16 @@ enum class ModelType {
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(
|
||||
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (config.debug) {
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
const OnlineTransducerModelConfig &config) {
|
||||
auto model_type = GetModelType(config);
|
||||
auto buffer = ReadFile(config.encoder_filename);
|
||||
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kLstm:
|
||||
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
|
||||
auto buffer = ReadFile(mgr, config.encoder_filename);
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kLstm:
|
||||
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
|
||||
case ModelType::kZipformer:
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||
case ModelType::kUnkown:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
@@ -22,6 +27,11 @@ class OnlineTransducerModel {
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
/** Stack a list of individual states into a batch.
|
||||
*
|
||||
* It is the inverse operation of `UnStackStates`.
|
||||
|
||||
@@ -13,6 +13,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
InitEncoder(config.encoder_filename);
|
||||
InitDecoder(config.decoder_filename);
|
||||
InitJoiner(config.joiner_filename);
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
|
||||
size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||
}
|
||||
|
||||
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||
void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
|
||||
size_t model_data_length) {
|
||||
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||
model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
||||
explicit OnlineZipformerTransducerModel(
|
||||
const OnlineTransducerModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineZipformerTransducerModel(AAssetManager *mgr,
|
||||
const OnlineTransducerModelConfig &config);
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||
|
||||
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
||||
OrtAllocator *Allocator() override { return allocator_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
void InitDecoder(const std::string &decoder_filename);
|
||||
void InitJoiner(const std::string &joiner_filename);
|
||||
void InitEncoder(void *model_data, size_t model_data_length);
|
||||
void InitDecoder(void *model_data, size_t model_data_length);
|
||||
void InitJoiner(void *model_data, size_t model_data_length);
|
||||
|
||||
private:
|
||||
Ort::Env env_;
|
||||
|
||||
@@ -3,9 +3,16 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#include "android/log.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename) {
|
||||
std::ifstream input(filename, std::ios::binary);
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
|
||||
return buffer;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
|
||||
if (!asset) {
|
||||
__android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
|
||||
"Read binary file: Load %s failed", filename.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
|
||||
size_t asset_length = AAsset_getLength(asset);
|
||||
|
||||
AAsset_close(asset);
|
||||
|
||||
std::vector<char> buffer(p, p + asset_length);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -14,6 +14,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
|
||||
std::fill(p, p + n, value);
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -7,11 +7,32 @@
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <strstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
SymbolTable::SymbolTable(const std::string &filename) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
|
||||
auto buf = ReadFile(mgr, filename);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
Init(is);
|
||||
}
|
||||
#endif
|
||||
|
||||
void SymbolTable::Init(std::istream &is) {
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// It manages mapping between symbols and integer IDs.
|
||||
@@ -22,6 +27,10 @@ class SymbolTable {
|
||||
/// Fields are separated by space(s).
|
||||
explicit SymbolTable(const std::string &filename);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
/// Return a string representation of this symbol table
|
||||
std::string ToString() const;
|
||||
|
||||
@@ -36,6 +45,9 @@ class SymbolTable {
|
||||
/// Return true if there is a given symbol in the symbol table.
|
||||
bool contains(const std::string &sym) const;
|
||||
|
||||
private:
|
||||
void Init(std::istream &is);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, int32_t> sym2id_;
|
||||
std::unordered_map<int32_t, std::string> id2sym_;
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#endif
|
||||
|
||||
#if __ANDROID_API__ >= 8
|
||||
#include <android/log.h>
|
||||
#include "android/log.h"
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
|
||||
Reference in New Issue
Block a user