// sherpa-onnx/jni/offline-speaker-diarization.cc // // Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-speaker-diarization.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/jni/common.h" namespace sherpa_onnx { static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig( JNIEnv *env, jobject config) { OfflineSpeakerDiarizationConfig ans; jclass cls = env->GetObjectClass(config); jfieldID fid; //---------- segmentation ---------- fid = env->GetFieldID( cls, "segmentation", "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;"); jobject segmentation_config = env->GetObjectField(config, fid); jclass segmentation_config_cls = env->GetObjectClass(segmentation_config); fid = env->GetFieldID( segmentation_config_cls, "pyannote", "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;"); jobject pyannote_config = env->GetObjectField(segmentation_config, fid); jclass pyannote_config_cls = env->GetObjectClass(pyannote_config); fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;"); jstring s = (jstring)env->GetObjectField(pyannote_config, fid); const char *p = env->GetStringUTFChars(s, nullptr); ans.segmentation.pyannote.model = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I"); ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid); fid = env->GetFieldID(segmentation_config_cls, "debug", "Z"); ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid); fid = env->GetFieldID(segmentation_config_cls, "provider", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(segmentation_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.segmentation.provider = p; env->ReleaseStringUTFChars(s, p); //---------- embedding ---------- fid = env->GetFieldID( cls, "embedding", "Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;"); jobject embedding_config = env->GetObjectField(config, fid); jclass embedding_config_cls = env->GetObjectClass(embedding_config); fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(embedding_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.embedding.model = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(embedding_config_cls, "numThreads", "I"); ans.embedding.num_threads = env->GetIntField(embedding_config, fid); fid = env->GetFieldID(embedding_config_cls, "debug", "Z"); ans.embedding.debug = env->GetBooleanField(embedding_config, fid); fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(embedding_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.embedding.provider = p; env->ReleaseStringUTFChars(s, p); //---------- clustering ---------- fid = env->GetFieldID(cls, "clustering", "Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;"); jobject clustering_config = env->GetObjectField(config, fid); jclass clustering_config_cls = env->GetObjectClass(clustering_config); fid = env->GetFieldID(clustering_config_cls, "numClusters", "I"); ans.clustering.num_clusters = env->GetIntField(clustering_config, fid); fid = env->GetFieldID(clustering_config_cls, "threshold", "F"); ans.clustering.threshold = env->GetFloatField(clustering_config, fid); // its own fields fid = env->GetFieldID(cls, "minDurationOn", "F"); ans.min_duration_on = env->GetFloatField(config, fid); fid = env->GetFieldID(cls, "minDurationOff", "F"); ans.min_duration_off = env->GetFloatField(config, fid); return ans; } } // namespace sherpa_onnx SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset( JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); return 0; } #endif auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); auto sd = new sherpa_onnx::OfflineSpeakerDiarization( #if __ANDROID_API__ >= 9 mgr, #endif config); return (jlong)sd; } SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile( JNIEnv *env, jobject /*obj*/, jobject _config) { auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); if (!config.Validate()) { SHERPA_ONNX_LOGE("Errors found in config!"); return 0; } auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config); return (jlong)sd; } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig( JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) { auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); auto sd = reinterpret_cast(ptr); sd->SetConfig(config); } SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { delete reinterpret_cast(ptr); } static jobjectArray ProcessImpl( JNIEnv *env, const std::vector &segments) { jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment"); jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr); jmethodID constructor = env->GetMethodID(cls, "", "(FFI)V"); for (int32_t i = 0; i != segments.size(); ++i) { const auto &s = segments[i]; jobject segment = env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker()); env->SetObjectArrayElement(obj_arr, i, segment); } return obj_arr; } SHERPA_ONNX_EXTERN_C JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process( JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { auto sd = reinterpret_cast(ptr); jfloat *p = env->GetFloatArrayElements(samples, nullptr); jsize n = env->GetArrayLength(samples); auto segments = sd->Process(p, n).SortByStartTime(); env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); return ProcessImpl(env, segments); } SHERPA_ONNX_EXTERN_C JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback( JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, jobject callback, jlong arg) { std::function callback_wrapper = [env, callback](int32_t num_processed_chunks, int32_t num_total_chunks, void *data) -> int { jclass cls = env->GetObjectClass(callback); jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;"); if (mid == nullptr) { SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it."); return 0; } jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks, num_total_chunks, (jlong)data); jclass jklass = env->GetObjectClass(ret); jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); return env->CallIntMethod(ret, int_value_mid); }; auto sd = reinterpret_cast(ptr); jfloat *p = env->GetFloatArrayElements(samples, nullptr); jsize n = env->GetArrayLength(samples); auto segments = sd->Process(p, n, callback_wrapper, reinterpret_cast(arg)) .SortByStartTime(); env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); return ProcessImpl(env, segments); } SHERPA_ONNX_EXTERN_C JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate( JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { return reinterpret_cast(ptr) ->SampleRate(); }