Support Android (#59)

This commit is contained in:
Fangjun Kuang
2023-02-24 13:57:03 +08:00
committed by GitHub
parent 5a5d029490
commit 9064b3f016
68 changed files with 1469 additions and 220 deletions

View File

@@ -5,6 +5,23 @@
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#include <stdio.h>
// SHERPA_ONNX_LOGE(fmt, ...): printf-style logging macro.
//
// Both variants write the formatted message to stderr and then a newline.
// When __ANDROID_API__ >= 8, the same arguments are additionally forwarded to
// __android_log_print() under the tag "sherpa-onnx". If __ANDROID_API__ is not
// defined, the preprocessor treats it as 0 and the #else variant is selected.
//
// The macro arguments appear more than once in the expansion, so an argument
// expression with side effects is evaluated once per appearance.
// The body is wrapped in do { ... } while (0) so that an invocation followed
// by a semicolon forms a single statement (e.g. inside an unbraced if/else).
#if __ANDROID_API__ >= 8
#include "android/log.h"
// Android variant: stderr plus the Android log.
// NOTE(review): the priority used is ANDROID_LOG_WARN even though the macro
// name ends in "E" -- confirm that this is intentional.
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
} while (0)
#else
// Fallback variant: stderr only.
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
} while (0)
#endif
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
do { \

View File

@@ -37,7 +37,6 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "OnlineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
@@ -49,7 +48,7 @@ class OnlineRecognizer::Impl {
explicit Impl(const OnlineRecognizerConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.tokens),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
@@ -59,7 +58,7 @@ class OnlineRecognizer::Impl {
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.tokens),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());

View File

@@ -27,7 +27,6 @@ struct OnlineRecognizerResult {
struct OnlineRecognizerConfig {
FeatureExtractorConfig feat_config;
OnlineTransducerModelConfig model_config;
std::string tokens;
EndpointConfig endpoint_config;
bool enable_endpoint;
@@ -35,12 +34,10 @@ struct OnlineRecognizerConfig {
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const std::string &tokens,
const EndpointConfig &endpoint_config,
bool enable_endpoint)
: feat_config(feat_config),
model_config(model_config),
tokens(tokens),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}

View File

@@ -14,6 +14,7 @@ std::string OnlineTransducerModelConfig::ToString() const {
os << "encoder_filename=\"" << encoder_filename << "\", ";
os << "decoder_filename=\"" << decoder_filename << "\", ";
os << "joiner_filename=\"" << joiner_filename << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ")";

View File

@@ -12,6 +12,7 @@ struct OnlineTransducerModelConfig {
std::string encoder_filename;
std::string decoder_filename;
std::string joiner_filename;
std::string tokens;
int32_t num_threads;
bool debug = false;
@@ -19,10 +20,12 @@ struct OnlineTransducerModelConfig {
OnlineTransducerModelConfig(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
int32_t num_threads, bool debug)
const std::string &tokens, int32_t num_threads,
bool debug)
: encoder_filename(encoder_filename),
decoder_filename(decoder_filename),
joiner_filename(joiner_filename),
tokens(tokens),
num_threads(num_threads),
debug(debug) {}

View File

@@ -141,9 +141,8 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
size_t asset_length = AAsset_getLength(asset);
AAsset_close(asset);
std::vector<char> buffer(p, p + asset_length);
AAsset_close(asset);
return buffer;
}

View File

@@ -65,7 +65,7 @@ as the device_name.
sherpa_onnx::OnlineRecognizerConfig config;
config.tokens = argv[1];
config.model_config.tokens = argv[1];
config.model_config.debug = false;
config.model_config.encoder_filename = argv[2];

View File

@@ -58,7 +58,7 @@ for a list of pre-trained models to download.
signal(SIGINT, Handler);
sherpa_onnx::OnlineRecognizerConfig config;
config.tokens = argv[1];
config.model_config.tokens = argv[1];
config.model_config.debug = false;
config.model_config.encoder_filename = argv[2];

View File

@@ -35,7 +35,7 @@ for a list of pre-trained models to download.
sherpa_onnx::OnlineRecognizerConfig config;
config.tokens = argv[1];
config.model_config.tokens = argv[1];
config.model_config.debug = false;
config.model_config.encoder_filename = argv[2];

View File

@@ -19,23 +19,9 @@
#include <fstream>
#endif
// SHERPA_ONNX_LOGE(fmt, ...): printf-style logging macro.
// Writes the formatted message and a trailing newline to stderr; when
// __ANDROID_API__ >= 8 it also forwards the arguments to
// __android_log_print() with priority ANDROID_LOG_WARN and tag "sherpa-onnx".
// NOTE(review): "sherpa-onnx/csrc/macros.h", included right after this block,
// defines a macro with the same name -- confirm only one definition is kept.
#if __ANDROID_API__ >= 8
#include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
} while (0)
#else
// Non-Android builds: stderr only.
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
} while (0)
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#define SHERPA_ONNX_EXTERN_C extern "C"
@@ -160,14 +146,6 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
ans.endpoint_config.rule3.min_utterance_length =
env->GetFloatField(rule3, fid);
//---------- tokens ----------
fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.tokens = p;
env->ReleaseStringUTFChars(s, p);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
@@ -175,8 +153,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
jstring s = (jstring)env->GetObjectField(model_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
@@ -192,6 +170,12 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
ans.model_config.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
@@ -226,7 +210,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
SHERPA_ONNX_LOGE("freed!");
delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
}
@@ -286,12 +269,9 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
return nullptr;
}
AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
size_t asset_length = AAsset_getLength(asset);
std::vector<char> buffer(asset_length);
AAsset_read(asset, buffer.data(), asset_length);
std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
std::istrstream is(buffer.data(), asset_length);
std::istrstream is(buffer.data(), buffer.size());
#else
std::ifstream is(p_filename, std::ios::binary);
#endif
@@ -300,9 +280,6 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
std::vector<float> samples =
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
#if __ANDROID_API__ >= 9
AAsset_close(asset);
#endif
env->ReleaseStringUTFChars(filename, p_filename);
if (!is_ok) {

View File

@@ -21,13 +21,12 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const std::string &,
const EndpointConfig &, bool>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
const OnlineTransducerModelConfig &, const EndpointConfig &,
bool>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def("__str__", &PyClass::ToString);

View File

@@ -14,13 +14,14 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
using PyClass = OnlineTransducerModelConfig;
py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
.def(py::init<const std::string &, const std::string &,
const std::string &, int32_t, bool>(),
const std::string &, const std::string &, int32_t, bool>(),
py::arg("encoder_filename"), py::arg("decoder_filename"),
py::arg("joiner_filename"), py::arg("num_threads"),
py::arg("debug") = false)
py::arg("joiner_filename"), py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false)
.def_readwrite("encoder_filename", &PyClass::encoder_filename)
.def_readwrite("decoder_filename", &PyClass::decoder_filename)
.def_readwrite("joiner_filename", &PyClass::joiner_filename)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def("__str__", &PyClass::ToString);

View File

@@ -85,6 +85,7 @@ class OnlineRecognizer(object):
encoder_filename=encoder,
decoder_filename=decoder,
joiner_filename=joiner,
tokens=tokens,
num_threads=num_threads,
)
@@ -102,7 +103,6 @@ class OnlineRecognizer(object):
recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
tokens=tokens,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
)