316 lines
10 KiB
C++
316 lines
10 KiB
C++
// sherpa-onnx/jni/jni.cc
|
|
//
|
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
|
// 2022 Pingfeng Luo
|
|
|
|
// TODO(fangjun): Add documentation to functions/methods in this file
|
|
// and also show how to use them with kotlin, possibly with java.
|
|
|
|
// If you use ndk, you can find "jni.h" inside
|
|
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
|
#include "jni.h"  // NOLINT

#include <cstdio>
#include <memory>
#include <string>
#include <strstream>
#include <vector>
|
|
|
|
#if __ANDROID_API__ >= 9
|
|
#include "android/asset_manager.h"
|
|
#include "android/asset_manager_jni.h"
|
|
#else
|
|
#include <fstream>
|
|
#endif
|
|
|
|
#if __ANDROID_API__ >= 8
#include <android/log.h>
// Forwards a printf-style message to logcat under the "sherpa-onnx" tag.
#define SHERPA_ONNX_LOGCAT(...) \
  __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", __VA_ARGS__)
#else
// No logcat on this target: the arguments are discarded without being
// evaluated, so only the stderr output below is produced.
#define SHERPA_ONNX_LOGCAT(...) ((void)0)
#endif

// SHERPA_ONNX_LOGE(fmt, ...)
//
// Writes a printf-style message followed by a newline to stderr and, when
// __ANDROID_API__ >= 8, also sends the same message to logcat.
//
// The first argument must be the format string. When logcat is enabled the
// arguments are expanded twice, so they must not have side effects.
// The do { } while (0) wrapper makes the macro usable as a single statement,
// e.g. in an unbraced if/else.
#define SHERPA_ONNX_LOGE(...)        \
  do {                               \
    fprintf(stderr, __VA_ARGS__);    \
    fprintf(stderr, "\n");           \
    SHERPA_ONNX_LOGCAT(__VA_ARGS__); \
  } while (0)
|
|
|
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
|
|
|
// Placed in front of every JNI entry point below to give it C linkage.
#define SHERPA_ONNX_EXTERN_C extern "C"
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
class SherpaOnnx {
|
|
public:
|
|
SherpaOnnx(
|
|
#if __ANDROID_API__ >= 9
|
|
AAssetManager *mgr,
|
|
#endif
|
|
const sherpa_onnx::OnlineRecognizerConfig &config)
|
|
: recognizer_(
|
|
#if __ANDROID_API__ >= 9
|
|
mgr,
|
|
#endif
|
|
config),
|
|
stream_(recognizer_.CreateStream()),
|
|
tail_padding_(16000 * 0.32, 0) {
|
|
}
|
|
|
|
void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
|
|
stream_->AcceptWaveform(sample_rate, samples, n);
|
|
Decode();
|
|
}
|
|
|
|
void InputFinished() const {
|
|
stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
|
|
stream_->InputFinished();
|
|
Decode();
|
|
}
|
|
|
|
const std::string GetText() const {
|
|
auto result = recognizer_.GetResult(stream_.get());
|
|
return result.text;
|
|
}
|
|
|
|
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
|
|
|
|
void Reset() const { return recognizer_.Reset(stream_.get()); }
|
|
|
|
private:
|
|
void Decode() const {
|
|
while (recognizer_.IsReady(stream_.get())) {
|
|
recognizer_.DecodeStream(stream_.get());
|
|
}
|
|
}
|
|
|
|
private:
|
|
sherpa_onnx::OnlineRecognizer recognizer_;
|
|
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
|
|
std::vector<float> tail_padding_;
|
|
};
|
|
|
|
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
|
OnlineRecognizerConfig ans;
|
|
|
|
jclass cls = env->GetObjectClass(config);
|
|
jfieldID fid;
|
|
|
|
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
|
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
|
|
|
//---------- feat config ----------
|
|
fid = env->GetFieldID(cls, "featConfig",
|
|
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
|
jobject feat_config = env->GetObjectField(config, fid);
|
|
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
|
|
|
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
|
|
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
|
|
|
|
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
|
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
|
|
|
//---------- enable endpoint ----------
|
|
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
|
|
ans.enable_endpoint = env->GetBooleanField(config, fid);
|
|
|
|
//---------- endpoint_config ----------
|
|
|
|
fid = env->GetFieldID(cls, "endpointConfig",
|
|
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
|
|
jobject endpoint_config = env->GetObjectField(config, fid);
|
|
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
|
|
|
|
fid = env->GetFieldID(endpoint_config_cls, "rule1",
|
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
|
jobject rule1 = env->GetObjectField(endpoint_config, fid);
|
|
jclass rule_class = env->GetObjectClass(rule1);
|
|
|
|
fid = env->GetFieldID(endpoint_config_cls, "rule2",
|
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
|
jobject rule2 = env->GetObjectField(endpoint_config, fid);
|
|
|
|
fid = env->GetFieldID(endpoint_config_cls, "rule3",
|
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
|
jobject rule3 = env->GetObjectField(endpoint_config, fid);
|
|
|
|
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
|
|
ans.endpoint_config.rule1.must_contain_nonsilence =
|
|
env->GetBooleanField(rule1, fid);
|
|
ans.endpoint_config.rule2.must_contain_nonsilence =
|
|
env->GetBooleanField(rule2, fid);
|
|
ans.endpoint_config.rule3.must_contain_nonsilence =
|
|
env->GetBooleanField(rule3, fid);
|
|
|
|
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
|
|
ans.endpoint_config.rule1.min_trailing_silence =
|
|
env->GetFloatField(rule1, fid);
|
|
ans.endpoint_config.rule2.min_trailing_silence =
|
|
env->GetFloatField(rule2, fid);
|
|
ans.endpoint_config.rule3.min_trailing_silence =
|
|
env->GetFloatField(rule3, fid);
|
|
|
|
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
|
|
ans.endpoint_config.rule1.min_utterance_length =
|
|
env->GetFloatField(rule1, fid);
|
|
ans.endpoint_config.rule2.min_utterance_length =
|
|
env->GetFloatField(rule2, fid);
|
|
ans.endpoint_config.rule3.min_utterance_length =
|
|
env->GetFloatField(rule3, fid);
|
|
|
|
//---------- tokens ----------
|
|
|
|
fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
|
|
jstring s = (jstring)env->GetObjectField(config, fid);
|
|
const char *p = env->GetStringUTFChars(s, nullptr);
|
|
ans.tokens = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
//---------- model config ----------
|
|
fid = env->GetFieldID(cls, "modelConfig",
|
|
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
|
jobject model_config = env->GetObjectField(config, fid);
|
|
jclass model_config_cls = env->GetObjectClass(model_config);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(model_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model_config.encoder_filename = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(model_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model_config.decoder_filename = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(model_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model_config.joiner_filename = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
|
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
|
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
|
|
|
return ans;
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|
|
|
|
// Creates a native SherpaOnnx object from the Java configuration `_config`
// and returns its address as an opaque jlong handle.
//
// The object is heap-allocated here and is only destroyed by
// Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(). A failure to obtain the
// asset manager is logged, but construction still proceeds with the null
// pointer.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif

  sherpa_onnx::OnlineRecognizerConfig native_config =
      sherpa_onnx::GetConfig(env, _config);

  sherpa_onnx::SherpaOnnx *handle = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      native_config);

  return reinterpret_cast<jlong>(handle);
}
|
|
|
|
// Logs "freed!" and then destroys the native SherpaOnnx object whose address
// is carried in `ptr`. The handle must not be used after this call.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  SHERPA_ONNX_LOGE("freed!");

  auto *model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  delete model;
}
|
|
|
|
// Calls Reset() on the native SherpaOnnx object referenced by the handle
// `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->Reset();
}
|
|
|
|
// Returns JNI_TRUE when IsEndpoint() of the native SherpaOnnx object
// referenced by the handle `ptr` is true, JNI_FALSE otherwise.
//
// The return type is jboolean, the JNI type for a Java boolean, rather than
// the C++ `bool`, whose size is implementation-defined.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  return model->IsEndpoint() ? JNI_TRUE : JNI_FALSE;
}
|
|
|
|
// Feeds the contents of the Java float array `samples`, given at
// `sample_rate`, to the native SherpaOnnx object referenced by the handle
// `ptr` and decodes them.
//
// Does nothing if the array elements cannot be obtained.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jfloat sample_rate) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    // GetFloatArrayElements() returns NULL on failure; there is nothing to
    // decode and nothing to release.
    return;
  }
  const jsize n = env->GetArrayLength(samples);

  model->DecodeSamples(sample_rate, p, n);

  // The samples are only read, so JNI_ABORT frees the buffer without copying
  // anything back into the Java array.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
|
|
|
|
// Calls InputFinished() on the native SherpaOnnx object referenced by the
// handle `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  model->InputFinished();
}
|
|
|
|
// Returns the text from GetText() of the native SherpaOnnx object referenced
// by the handle `ptr`, as a new Java string.
//
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
//
// NOTE(review): NewStringUTF() expects modified UTF-8; presumably the
// recognizer emits standard UTF-8, which differs for characters outside the
// BMP -- confirm whether such text can occur.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  auto text = model->GetText();
  const char *utf_chars = text.c_str();
  return env->NewStringUTF(utf_chars);
}
|
|
|
|
// Reads the wave file `filename` and returns its samples as a new Java float
// array.
//
// When __ANDROID_API__ >= 9 the file is read from the assets reachable through
// `asset_manager`; otherwise it is opened from the file system.
// `expected_sample_rate` is forwarded to sherpa_onnx::ReadWave().
//
// Returns nullptr when the file name, asset manager or asset cannot be
// obtained, the asset is empty or cannot be read completely, ReadWave()
// reports failure, or the result array cannot be allocated.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
    jfloat expected_sample_rate) {
  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (p_filename == nullptr) {
    return nullptr;
  }
  // Copy the name and release the JNI buffer right away so that none of the
  // early returns below can leak it.
  const std::string path(p_filename);
  env->ReleaseStringUTFChars(filename, p_filename);

#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return nullptr;
  }

  AAsset *asset = AAssetManager_open(mgr, path.c_str(), AASSET_MODE_BUFFER);
  if (!asset) {
    SHERPA_ONNX_LOGE("Failed to open asset: %s", path.c_str());
    return nullptr;
  }

  const auto asset_length = AAsset_getLength(asset);
  if (asset_length <= 0) {
    // An empty buffer must not reach istrstream: a length of 0 makes it treat
    // the pointer as a NUL-terminated string.
    SHERPA_ONNX_LOGE("Asset is empty: %s", path.c_str());
    AAsset_close(asset);
    return nullptr;
  }

  std::vector<char> buffer(static_cast<size_t>(asset_length));
  size_t num_read = 0;
  while (num_read < buffer.size()) {
    // AAsset_read() returns the number of bytes read, 0 on EOF, < 0 on error.
    const int r =
        AAsset_read(asset, buffer.data() + num_read, buffer.size() - num_read);
    if (r <= 0) {
      break;
    }
    num_read += static_cast<size_t>(r);
  }
  // Everything needed is now in `buffer`.
  AAsset_close(asset);

  if (num_read != buffer.size()) {
    SHERPA_ONNX_LOGE("Failed to read asset: %s", path.c_str());
    return nullptr;
  }

  std::istrstream is(buffer.data(), buffer.size());
#else
  std::ifstream is(path.c_str(), std::ios::binary);
#endif

  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);

  if (!is_ok) {
    return nullptr;
  }

  const jsize num_samples = static_cast<jsize>(samples.size());
  jfloatArray ans = env->NewFloatArray(num_samples);
  if (ans == nullptr) {
    // NewFloatArray() returns NULL if the array cannot be constructed.
    return nullptr;
  }
  env->SetFloatArrayRegion(ans, 0, num_samples, samples.data());
  return ans;
}
|