// sherpa-onnx/jni/jni.cc // // Copyright (c) 2022-2023 Xiaomi Corporation // 2022 Pingfeng Luo // TODO(fangjun): Add documentation to functions/methods in this file // and also show how to use them with kotlin, possibly with java. // If you use ndk, you can find "jni.h" inside // android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include #include "jni.h" // NOLINT #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #else #include #endif #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/wave-reader.h" #define SHERPA_ONNX_EXTERN_C extern "C" namespace sherpa_onnx { class SherpaOnnx { public: SherpaOnnx( #if __ANDROID_API__ >= 9 AAssetManager *mgr, #endif const sherpa_onnx::OnlineRecognizerConfig &config) : recognizer_( #if __ANDROID_API__ >= 9 mgr, #endif config), stream_(recognizer_.CreateStream()) { } void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) const { stream_->AcceptWaveform(sample_rate, samples, n); } void InputFinished() const { std::vector tail_padding(16000 * 0.32, 0); stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size()); stream_->InputFinished(); } const std::string GetText() const { auto result = recognizer_.GetResult(stream_.get()); return result.text; } bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } bool IsReady() const { return recognizer_.IsReady(stream_.get()); } void Reset() const { return recognizer_.Reset(stream_.get()); } void Decode() const { recognizer_.DecodeStream(stream_.get()); } private: sherpa_onnx::OnlineRecognizer recognizer_; std::unique_ptr stream_; }; static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { OnlineRecognizerConfig ans; jclass cls = env->GetObjectClass(config); jfieldID fid; // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html //---------- decoding ---------- fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); jstring s = (jstring)env->GetObjectField(config, fid); const char *p = env->GetStringUTFChars(s, nullptr); ans.decoding_method = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(cls, "maxActivePaths", "I"); ans.max_active_paths = env->GetIntField(config, fid); //---------- feat config ---------- fid = env->GetFieldID(cls, "featConfig", "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); jobject feat_config = env->GetObjectField(config, fid); jclass feat_config_cls = env->GetObjectClass(feat_config); fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); //---------- enable endpoint ---------- fid = env->GetFieldID(cls, "enableEndpoint", "Z"); ans.enable_endpoint = env->GetBooleanField(config, fid); //---------- endpoint_config ---------- fid = env->GetFieldID(cls, "endpointConfig", "Lcom/k2fsa/sherpa/onnx/EndpointConfig;"); jobject endpoint_config = env->GetObjectField(config, fid); jclass endpoint_config_cls = env->GetObjectClass(endpoint_config); fid = env->GetFieldID(endpoint_config_cls, "rule1", "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); jobject rule1 = env->GetObjectField(endpoint_config, fid); jclass rule_class = env->GetObjectClass(rule1); fid = env->GetFieldID(endpoint_config_cls, "rule2", "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); jobject rule2 = env->GetObjectField(endpoint_config, fid); fid = env->GetFieldID(endpoint_config_cls, "rule3", "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); jobject rule3 = env->GetObjectField(endpoint_config, fid); fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z"); ans.endpoint_config.rule1.must_contain_nonsilence = env->GetBooleanField(rule1, fid); ans.endpoint_config.rule2.must_contain_nonsilence = env->GetBooleanField(rule2, fid); ans.endpoint_config.rule3.must_contain_nonsilence = env->GetBooleanField(rule3, fid); fid = env->GetFieldID(rule_class, "minTrailingSilence", "F"); ans.endpoint_config.rule1.min_trailing_silence = env->GetFloatField(rule1, fid); ans.endpoint_config.rule2.min_trailing_silence = env->GetFloatField(rule2, fid); ans.endpoint_config.rule3.min_trailing_silence = env->GetFloatField(rule3, fid); fid = env->GetFieldID(rule_class, "minUtteranceLength", "F"); ans.endpoint_config.rule1.min_utterance_length = env->GetFloatField(rule1, fid); ans.endpoint_config.rule2.min_utterance_length = env->GetFloatField(rule2, fid); ans.endpoint_config.rule3.min_utterance_length = env->GetFloatField(rule3, fid); //---------- model config ---------- fid = env->GetFieldID(cls, "modelConfig", "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); jobject model_config = env->GetObjectField(config, fid); jclass model_config_cls = env->GetObjectClass(model_config); fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.encoder_filename = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.decoder_filename = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.joiner_filename = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.tokens = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "numThreads", "I"); ans.model_config.num_threads = env->GetIntField(model_config, fid); fid = env->GetFieldID(model_config_cls, "debug", "Z"); ans.model_config.debug = env->GetBooleanField(model_config, fid); return ans; } } // namespace sherpa_onnx SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new( JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); } #endif auto config = sherpa_onnx::GetConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); auto model = new sherpa_onnx::SherpaOnnx( #if __ANDROID_API__ >= 9 mgr, #endif config); return (jlong)model; } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete( JNIEnv *env, jobject /*obj*/, jlong ptr) { delete reinterpret_cast(ptr); } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( JNIEnv *env, jobject /*obj*/, jlong ptr) { auto model = reinterpret_cast(ptr); model->Reset(); } SHERPA_ONNX_EXTERN_C JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady( JNIEnv *env, jobject /*obj*/, jlong ptr) { auto model = reinterpret_cast(ptr); return model->IsReady(); } SHERPA_ONNX_EXTERN_C JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint( JNIEnv *env, jobject /*obj*/, jlong ptr) { auto model = reinterpret_cast(ptr); return model->IsEndpoint(); } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode( JNIEnv *env, jobject /*obj*/, jlong ptr) { auto model = reinterpret_cast(ptr); model->Decode(); } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform( JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, jint sample_rate) { auto model = reinterpret_cast(ptr); jfloat *p = env->GetFloatArrayElements(samples, nullptr); jsize n = env->GetArrayLength(samples); model->AcceptWaveform(sample_rate, p, n); env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished( JNIEnv *env, jobject /*obj*/, jlong ptr) { reinterpret_cast(ptr)->InputFinished(); } SHERPA_ONNX_EXTERN_C JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( JNIEnv *env, jobject /*obj*/, jlong ptr) { // see // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni auto text = reinterpret_cast(ptr)->GetText(); return env->NewStringUTF(text.c_str()); } SHERPA_ONNX_EXTERN_C JNIEXPORT jfloatArray JNICALL Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename, jfloat expected_sample_rate) { const char *p_filename = env->GetStringUTFChars(filename, nullptr); #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); return nullptr; } std::vector buffer = sherpa_onnx::ReadFile(mgr, p_filename); std::istrstream is(buffer.data(), buffer.size()); #else std::ifstream is(p_filename, std::ios::binary); #endif bool is_ok = false; std::vector samples = sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok); env->ReleaseStringUTFChars(filename, p_filename); if (!is_ok) { return nullptr; } jfloatArray ans = env->NewFloatArray(samples.size()); env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data()); return ans; }