Add build script for Android armv8a (#58)
This commit is contained in:
61
build-android-arm64-v8a.sh
Executable file
61
build-android-arm64-v8a.sh
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
dir=build-android-arm64-v8a
|
||||||
|
|
||||||
|
mkdir -p $dir
|
||||||
|
cd $dir
|
||||||
|
|
||||||
|
# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
|
||||||
|
# (optional) remove the hardcoded debug flag in Android NDK android-ndk
|
||||||
|
# issue: https://github.com/android/ndk/issues/243
|
||||||
|
#
|
||||||
|
# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23
|
||||||
|
# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23
|
||||||
|
#
|
||||||
|
# delete "-g" line
|
||||||
|
#
|
||||||
|
# list(APPEND ANDROID_COMPILER_FLAGS
|
||||||
|
# -g
|
||||||
|
# -DANDROID
|
||||||
|
|
||||||
|
|
||||||
|
if [ -z $ANDROID_NDK ]; then
|
||||||
|
ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
|
||||||
|
# or use
|
||||||
|
# ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
|
||||||
|
#
|
||||||
|
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
|
||||||
|
# and some other files like the file "build/cmake/android.toolchain.cmake"
|
||||||
|
|
||||||
|
if [ ! -d $ANDROID_NDK ]; then
|
||||||
|
# For macOS, I have installed Android Studio, select the menu
|
||||||
|
# Tools -> SDK manager -> Android SDK
|
||||||
|
# and set "Android SDK location" to /Users/fangjun/software/my-android
|
||||||
|
ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d $ANDROID_NDK ]; then
|
||||||
|
echo Please set the environment variable ANDROID_NDK before you run this script
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "ANDROID_NDK: $ANDROID_NDK"
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
|
||||||
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_JNI=ON \
|
||||||
|
-DCMAKE_INSTALL_PREFIX=./install \
|
||||||
|
-DANDROID_ABI="arm64-v8a" \
|
||||||
|
-DANDROID_PLATFORM=android-21 ..
|
||||||
|
# make VERBOSE=1 -j4
|
||||||
|
make -j4
|
||||||
|
make install/strip
|
||||||
|
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
function(download_onnxruntime)
|
function(download_onnxruntime)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
if(CMAKE_SYSTEM_NAME STREQUAL Linux)
|
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
|
|
||||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
|
||||||
# For embedded systems
|
# For embedded systems
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
@@ -13,8 +15,7 @@ function(download_onnxruntime)
|
|||||||
)
|
)
|
||||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
|
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
|
||||||
set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
|
set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
|
||||||
|
elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
|
||||||
else()
|
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download onnxruntime
|
# please pre-download onnxruntime
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
@@ -33,7 +34,6 @@ function(download_onnxruntime)
|
|||||||
#
|
#
|
||||||
# ./include
|
# ./include
|
||||||
# It contains all the needed header files
|
# It contains all the needed header files
|
||||||
endif()
|
|
||||||
elseif(APPLE)
|
elseif(APPLE)
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download onnxruntime
|
# please pre-download onnxruntime
|
||||||
@@ -69,6 +69,8 @@ function(download_onnxruntime)
|
|||||||
# ./include
|
# ./include
|
||||||
# It contains all the needed header files
|
# It contains all the needed header files
|
||||||
else()
|
else()
|
||||||
|
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
|
||||||
|
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
|
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -91,11 +93,15 @@ function(download_onnxruntime)
|
|||||||
endif()
|
endif()
|
||||||
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||||
|
|
||||||
|
if(ANDROID)
|
||||||
|
set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so)
|
||||||
|
else()
|
||||||
find_library(location_onnxruntime onnxruntime
|
find_library(location_onnxruntime onnxruntime
|
||||||
PATHS
|
PATHS
|
||||||
"${onnxruntime_SOURCE_DIR}/lib"
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
NO_CMAKE_SYSTEM_PATH
|
NO_CMAKE_SYSTEM_PATH
|
||||||
)
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ endif()
|
|||||||
|
|
||||||
add_library(sherpa-onnx-core ${sources})
|
add_library(sherpa-onnx-core ${sources})
|
||||||
|
|
||||||
|
if(ANDROID_NDK)
|
||||||
|
target_link_libraries(sherpa-onnx-core android log)
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_libraries(sherpa-onnx-core
|
target_link_libraries(sherpa-onnx-core
|
||||||
onnxruntime
|
onnxruntime
|
||||||
kaldi-native-fbank-core
|
kaldi-native-fbank-core
|
||||||
|
|||||||
@@ -12,6 +12,11 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/cat.h"
|
#include "sherpa-onnx/csrc/cat.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
|||||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||||
|
|
||||||
InitEncoder(config.encoder_filename);
|
{
|
||||||
InitDecoder(config.decoder_filename);
|
auto buf = ReadFile(config.encoder_filename);
|
||||||
InitJoiner(config.joiner_filename);
|
InitEncoder(buf.data(), buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
{
|
||||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
auto buf = ReadFile(config.decoder_filename);
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||||
|
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||||
|
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
config_(config),
|
||||||
|
sess_opts_{},
|
||||||
|
allocator_{} {
|
||||||
|
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||||
|
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void OnlineLstmTransducerModel::InitEncoder(void *model_data,
|
||||||
|
size_t model_data_length) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
&encoder_input_names_ptr_);
|
&encoder_input_names_ptr_);
|
||||||
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
|||||||
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
|
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
void OnlineLstmTransducerModel::InitDecoder(void *model_data,
|
||||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
size_t model_data_length) {
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
&decoder_input_names_ptr_);
|
&decoder_input_names_ptr_);
|
||||||
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
|||||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
void OnlineLstmTransducerModel::InitJoiner(void *model_data,
|
||||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
size_t model_data_length) {
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||||
&joiner_input_names_ptr_);
|
&joiner_input_names_ptr_);
|
||||||
|
|||||||
@@ -9,6 +9,11 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
|||||||
public:
|
public:
|
||||||
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineLstmTransducerModel(AAssetManager *mgr,
|
||||||
|
const OnlineTransducerModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
std::vector<Ort::Value> StackStates(
|
std::vector<Ort::Value> StackStates(
|
||||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||||
|
|
||||||
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
|||||||
OrtAllocator *Allocator() override { return allocator_; }
|
OrtAllocator *Allocator() override { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitEncoder(const std::string &encoder_filename);
|
void InitEncoder(void *model_data, size_t model_data_length);
|
||||||
void InitDecoder(const std::string &decoder_filename);
|
void InitDecoder(void *model_data, size_t model_data_length);
|
||||||
void InitJoiner(const std::string &joiner_filename);
|
void InitJoiner(void *model_data, size_t model_data_length);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Ort::Env env_;
|
Ort::Env env_;
|
||||||
|
|||||||
@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
|
|||||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||||
|
sym_(mgr, config.tokens),
|
||||||
|
endpoint_(config_.endpoint_config) {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
std::unique_ptr<OnlineStream> CreateStream() const {
|
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||||
stream->SetResult(decoder_->GetEmptyResult());
|
stream->SetResult(decoder_->GetEmptyResult());
|
||||||
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
|
|||||||
|
|
||||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
|
||||||
|
const OnlineRecognizerConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
OnlineRecognizer::~OnlineRecognizer() = default;
|
OnlineRecognizer::~OnlineRecognizer() = default;
|
||||||
|
|
||||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||||
|
|||||||
@@ -8,6 +8,11 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/endpoint.h"
|
#include "sherpa-onnx/csrc/endpoint.h"
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
#include "sherpa-onnx/csrc/online-stream.h"
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
|
|||||||
class OnlineRecognizer {
|
class OnlineRecognizer {
|
||||||
public:
|
public:
|
||||||
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
~OnlineRecognizer();
|
~OnlineRecognizer();
|
||||||
|
|
||||||
/// Create a stream for decoding.
|
/// Create a stream for decoding.
|
||||||
|
|||||||
@@ -3,6 +3,11 @@
|
|||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -18,15 +23,16 @@ enum class ModelType {
|
|||||||
kUnkown,
|
kUnkown,
|
||||||
};
|
};
|
||||||
|
|
||||||
static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||||
|
bool debug) {
|
||||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||||
Ort::SessionOptions sess_opts;
|
Ort::SessionOptions sess_opts;
|
||||||
|
|
||||||
auto sess = std::make_unique<Ort::Session>(
|
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||||
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
|
sess_opts);
|
||||||
|
|
||||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||||
if (config.debug) {
|
if (debug) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
PrintModelMetadata(os, meta_data);
|
PrintModelMetadata(os, meta_data);
|
||||||
fprintf(stderr, "%s\n", os.str().c_str());
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
|||||||
|
|
||||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||||
const OnlineTransducerModelConfig &config) {
|
const OnlineTransducerModelConfig &config) {
|
||||||
auto model_type = GetModelType(config);
|
auto buffer = ReadFile(config.encoder_filename);
|
||||||
|
|
||||||
|
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
|
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case ModelType::kLstm:
|
case ModelType::kLstm:
|
||||||
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||||
|
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
|
||||||
|
auto buffer = ReadFile(mgr, config.encoder_filename);
|
||||||
|
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
|
|
||||||
|
switch (model_type) {
|
||||||
|
case ModelType::kLstm:
|
||||||
|
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
|
||||||
|
case ModelType::kZipformer:
|
||||||
|
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||||
|
case ModelType::kUnkown:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable code
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -8,6 +8,11 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
@@ -22,6 +27,11 @@ class OnlineTransducerModel {
|
|||||||
static std::unique_ptr<OnlineTransducerModel> Create(
|
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||||
const OnlineTransducerModelConfig &config);
|
const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||||
|
AAssetManager *mgr, const OnlineTransducerModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
/** Stack a list of individual states into a batch.
|
/** Stack a list of individual states into a batch.
|
||||||
*
|
*
|
||||||
* It is the inverse operation of `UnStackStates`.
|
* It is the inverse operation of `UnStackStates`.
|
||||||
|
|||||||
@@ -13,6 +13,11 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/cat.h"
|
#include "sherpa-onnx/csrc/cat.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
|||||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||||
|
|
||||||
InitEncoder(config.encoder_filename);
|
{
|
||||||
InitDecoder(config.decoder_filename);
|
auto buf = ReadFile(config.encoder_filename);
|
||||||
InitJoiner(config.joiner_filename);
|
InitEncoder(buf.data(), buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
{
|
||||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
auto buf = ReadFile(config.decoder_filename);
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
|
||||||
|
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||||
|
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
config_(config),
|
||||||
|
sess_opts_{},
|
||||||
|
allocator_{} {
|
||||||
|
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||||
|
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||||
|
InitJoiner(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
|
||||||
|
size_t model_data_length) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
&encoder_input_names_ptr_);
|
&encoder_input_names_ptr_);
|
||||||
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
|
||||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
size_t model_data_length) {
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
&decoder_input_names_ptr_);
|
&decoder_input_names_ptr_);
|
||||||
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
|
|||||||
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {
|
void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
|
||||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
size_t model_data_length) {
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
|
||||||
|
model_data_length, sess_opts_);
|
||||||
|
|
||||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||||
&joiner_input_names_ptr_);
|
&joiner_input_names_ptr_);
|
||||||
|
|||||||
@@ -9,6 +9,11 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
|||||||
explicit OnlineZipformerTransducerModel(
|
explicit OnlineZipformerTransducerModel(
|
||||||
const OnlineTransducerModelConfig &config);
|
const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineZipformerTransducerModel(AAssetManager *mgr,
|
||||||
|
const OnlineTransducerModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
std::vector<Ort::Value> StackStates(
|
std::vector<Ort::Value> StackStates(
|
||||||
const std::vector<std::vector<Ort::Value>> &states) const override;
|
const std::vector<std::vector<Ort::Value>> &states) const override;
|
||||||
|
|
||||||
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
|||||||
OrtAllocator *Allocator() override { return allocator_; }
|
OrtAllocator *Allocator() override { return allocator_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitEncoder(const std::string &encoder_filename);
|
void InitEncoder(void *model_data, size_t model_data_length);
|
||||||
void InitDecoder(const std::string &decoder_filename);
|
void InitDecoder(void *model_data, size_t model_data_length);
|
||||||
void InitJoiner(const std::string &joiner_filename);
|
void InitJoiner(void *model_data, size_t model_data_length);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Ort::Env env_;
|
Ort::Env env_;
|
||||||
|
|||||||
@@ -3,9 +3,16 @@
|
|||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#include "android/log.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<char> ReadFile(const std::string &filename) {
|
||||||
|
std::ifstream input(filename, std::ios::binary);
|
||||||
|
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||||
|
AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
|
||||||
|
if (!asset) {
|
||||||
|
__android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
|
||||||
|
"Read binary file: Load %s failed", filename.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
|
||||||
|
size_t asset_length = AAsset_getLength(asset);
|
||||||
|
|
||||||
|
AAsset_close(asset);
|
||||||
|
|
||||||
|
std::vector<char> buffer(p, p + asset_length);
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -14,6 +14,11 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
|
|||||||
std::fill(p, p + n, value);
|
std::fill(p, p + n, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<char> ReadFile(const std::string &filename);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
|||||||
@@ -7,11 +7,32 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <strstream>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
SymbolTable::SymbolTable(const std::string &filename) {
|
SymbolTable::SymbolTable(const std::string &filename) {
|
||||||
std::ifstream is(filename);
|
std::ifstream is(filename);
|
||||||
|
Init(is);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
|
||||||
|
auto buf = ReadFile(mgr, filename);
|
||||||
|
|
||||||
|
std::istrstream is(buf.data(), buf.size());
|
||||||
|
Init(is);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void SymbolTable::Init(std::istream &is) {
|
||||||
std::string sym;
|
std::string sym;
|
||||||
int32_t id;
|
int32_t id;
|
||||||
while (is >> sym >> id) {
|
while (is >> sym >> id) {
|
||||||
|
|||||||
@@ -8,6 +8,11 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
/// It manages mapping between symbols and integer IDs.
|
/// It manages mapping between symbols and integer IDs.
|
||||||
@@ -22,6 +27,10 @@ class SymbolTable {
|
|||||||
/// Fields are separated by space(s).
|
/// Fields are separated by space(s).
|
||||||
explicit SymbolTable(const std::string &filename);
|
explicit SymbolTable(const std::string &filename);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
SymbolTable(AAssetManager *mgr, const std::string &filename);
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Return a string representation of this symbol table
|
/// Return a string representation of this symbol table
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
@@ -36,6 +45,9 @@ class SymbolTable {
|
|||||||
/// Return true if there is a given symbol in the symbol table.
|
/// Return true if there is a given symbol in the symbol table.
|
||||||
bool contains(const std::string &sym) const;
|
bool contains(const std::string &sym) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(std::istream &is);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, int32_t> sym2id_;
|
std::unordered_map<std::string, int32_t> sym2id_;
|
||||||
std::unordered_map<int32_t, std::string> id2sym_;
|
std::unordered_map<int32_t, std::string> id2sym_;
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 8
|
#if __ANDROID_API__ >= 8
|
||||||
#include <android/log.h>
|
#include "android/log.h"
|
||||||
#define SHERPA_ONNX_LOGE(...) \
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
do { \
|
do { \
|
||||||
fprintf(stderr, ##__VA_ARGS__); \
|
fprintf(stderr, ##__VA_ARGS__); \
|
||||||
|
|||||||
Reference in New Issue
Block a user