Support Android (#59)
This commit is contained in:
@@ -5,6 +5,23 @@
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
|
||||
#define SHERPA_ONNX_CSRC_MACROS_H_
|
||||
#include <stdio.h>
|
||||
|
||||
#if __ANDROID_API__ >= 8
|
||||
#include "android/log.h"
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
#else
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||
do { \
|
||||
|
||||
@@ -37,7 +37,6 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "OnlineRecognizerConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
|
||||
|
||||
@@ -49,7 +48,7 @@ class OnlineRecognizer::Impl {
|
||||
explicit Impl(const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.tokens),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
@@ -59,7 +58,7 @@ class OnlineRecognizer::Impl {
|
||||
explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.tokens),
|
||||
sym_(mgr, config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
|
||||
@@ -27,7 +27,6 @@ struct OnlineRecognizerResult {
|
||||
struct OnlineRecognizerConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineTransducerModelConfig model_config;
|
||||
std::string tokens;
|
||||
EndpointConfig endpoint_config;
|
||||
bool enable_endpoint;
|
||||
|
||||
@@ -35,12 +34,10 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const std::string &tokens,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
tokens(tokens),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint) {}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ std::string OnlineTransducerModelConfig::ToString() const {
|
||||
os << "encoder_filename=\"" << encoder_filename << "\", ";
|
||||
os << "decoder_filename=\"" << decoder_filename << "\", ";
|
||||
os << "joiner_filename=\"" << joiner_filename << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ")";
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ struct OnlineTransducerModelConfig {
|
||||
std::string encoder_filename;
|
||||
std::string decoder_filename;
|
||||
std::string joiner_filename;
|
||||
std::string tokens;
|
||||
int32_t num_threads;
|
||||
bool debug = false;
|
||||
|
||||
@@ -19,10 +20,12 @@ struct OnlineTransducerModelConfig {
|
||||
OnlineTransducerModelConfig(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
int32_t num_threads, bool debug)
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
bool debug)
|
||||
: encoder_filename(encoder_filename),
|
||||
decoder_filename(decoder_filename),
|
||||
joiner_filename(joiner_filename),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
|
||||
@@ -141,9 +141,8 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
|
||||
size_t asset_length = AAsset_getLength(asset);
|
||||
|
||||
AAsset_close(asset);
|
||||
|
||||
std::vector<char> buffer(p, p + asset_length);
|
||||
AAsset_close(asset);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ as the device_name.
|
||||
|
||||
sherpa_onnx::OnlineRecognizerConfig config;
|
||||
|
||||
config.tokens = argv[1];
|
||||
config.model_config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
|
||||
@@ -58,7 +58,7 @@ for a list of pre-trained models to download.
|
||||
signal(SIGINT, Handler);
|
||||
|
||||
sherpa_onnx::OnlineRecognizerConfig config;
|
||||
config.tokens = argv[1];
|
||||
config.model_config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
|
||||
@@ -35,7 +35,7 @@ for a list of pre-trained models to download.
|
||||
|
||||
sherpa_onnx::OnlineRecognizerConfig config;
|
||||
|
||||
config.tokens = argv[1];
|
||||
config.model_config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
|
||||
@@ -19,23 +19,9 @@
|
||||
#include <fstream>
|
||||
#endif
|
||||
|
||||
#if __ANDROID_API__ >= 8
|
||||
#include "android/log.h"
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
#else
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
#define SHERPA_ONNX_EXTERN_C extern "C"
|
||||
@@ -160,14 +146,6 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
ans.endpoint_config.rule3.min_utterance_length =
|
||||
env->GetFloatField(rule3, fid);
|
||||
|
||||
//---------- tokens ----------
|
||||
|
||||
fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||
@@ -175,8 +153,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
jstring s = (jstring)env->GetObjectField(model_config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.encoder_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
@@ -192,6 +170,12 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
ans.model_config.joiner_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
@@ -226,7 +210,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
SHERPA_ONNX_LOGE("freed!");
|
||||
delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||
}
|
||||
|
||||
@@ -286,12 +269,9 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
|
||||
size_t asset_length = AAsset_getLength(asset);
|
||||
std::vector<char> buffer(asset_length);
|
||||
AAsset_read(asset, buffer.data(), asset_length);
|
||||
std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
|
||||
|
||||
std::istrstream is(buffer.data(), asset_length);
|
||||
std::istrstream is(buffer.data(), buffer.size());
|
||||
#else
|
||||
std::ifstream is(p_filename, std::ios::binary);
|
||||
#endif
|
||||
@@ -300,9 +280,6 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAsset_close(asset);
|
||||
#endif
|
||||
env->ReleaseStringUTFChars(filename, p_filename);
|
||||
|
||||
if (!is_ok) {
|
||||
|
||||
@@ -21,13 +21,12 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const std::string &,
|
||||
const EndpointConfig &, bool>(),
|
||||
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
|
||||
const OnlineTransducerModelConfig &, const EndpointConfig &,
|
||||
bool>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
|
||||
@@ -14,13 +14,14 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
|
||||
using PyClass = OnlineTransducerModelConfig;
|
||||
py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, int32_t, bool>(),
|
||||
const std::string &, const std::string &, int32_t, bool>(),
|
||||
py::arg("encoder_filename"), py::arg("decoder_filename"),
|
||||
py::arg("joiner_filename"), py::arg("num_threads"),
|
||||
py::arg("debug") = false)
|
||||
py::arg("joiner_filename"), py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false)
|
||||
.def_readwrite("encoder_filename", &PyClass::encoder_filename)
|
||||
.def_readwrite("decoder_filename", &PyClass::decoder_filename)
|
||||
.def_readwrite("joiner_filename", &PyClass::joiner_filename)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
|
||||
@@ -85,6 +85,7 @@ class OnlineRecognizer(object):
|
||||
encoder_filename=encoder,
|
||||
decoder_filename=decoder,
|
||||
joiner_filename=joiner,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
@@ -102,7 +103,6 @@ class OnlineRecognizer(object):
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
tokens=tokens,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user