220 lines
7.9 KiB
C++
220 lines
7.9 KiB
C++
// sherpa-onnx/jni/offline-speaker-diarization.cc
|
|
//
|
|
// Copyright (c) 2024 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
|
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
#include "sherpa-onnx/jni/common.h"
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig(
|
|
JNIEnv *env, jobject config) {
|
|
OfflineSpeakerDiarizationConfig ans;
|
|
|
|
jclass cls = env->GetObjectClass(config);
|
|
jfieldID fid;
|
|
|
|
//---------- segmentation ----------
|
|
fid = env->GetFieldID(
|
|
cls, "segmentation",
|
|
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;");
|
|
jobject segmentation_config = env->GetObjectField(config, fid);
|
|
jclass segmentation_config_cls = env->GetObjectClass(segmentation_config);
|
|
|
|
fid = env->GetFieldID(
|
|
segmentation_config_cls, "pyannote",
|
|
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;");
|
|
jobject pyannote_config = env->GetObjectField(segmentation_config, fid);
|
|
jclass pyannote_config_cls = env->GetObjectClass(pyannote_config);
|
|
|
|
fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;");
|
|
jstring s = (jstring)env->GetObjectField(pyannote_config, fid);
|
|
const char *p = env->GetStringUTFChars(s, nullptr);
|
|
ans.segmentation.pyannote.model = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I");
|
|
ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid);
|
|
|
|
fid = env->GetFieldID(segmentation_config_cls, "debug", "Z");
|
|
ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid);
|
|
|
|
fid = env->GetFieldID(segmentation_config_cls, "provider",
|
|
"Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(segmentation_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.segmentation.provider = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
//---------- embedding ----------
|
|
fid = env->GetFieldID(
|
|
cls, "embedding",
|
|
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;");
|
|
jobject embedding_config = env->GetObjectField(config, fid);
|
|
jclass embedding_config_cls = env->GetObjectClass(embedding_config);
|
|
|
|
fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(embedding_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.embedding.model = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(embedding_config_cls, "numThreads", "I");
|
|
ans.embedding.num_threads = env->GetIntField(embedding_config, fid);
|
|
|
|
fid = env->GetFieldID(embedding_config_cls, "debug", "Z");
|
|
ans.embedding.debug = env->GetBooleanField(embedding_config, fid);
|
|
|
|
fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(embedding_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.embedding.provider = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
//---------- clustering ----------
|
|
fid = env->GetFieldID(cls, "clustering",
|
|
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;");
|
|
jobject clustering_config = env->GetObjectField(config, fid);
|
|
jclass clustering_config_cls = env->GetObjectClass(clustering_config);
|
|
|
|
fid = env->GetFieldID(clustering_config_cls, "numClusters", "I");
|
|
ans.clustering.num_clusters = env->GetIntField(clustering_config, fid);
|
|
|
|
fid = env->GetFieldID(clustering_config_cls, "threshold", "F");
|
|
ans.clustering.threshold = env->GetFloatField(clustering_config, fid);
|
|
|
|
// its own fields
|
|
fid = env->GetFieldID(cls, "minDurationOn", "F");
|
|
ans.min_duration_on = env->GetFloatField(config, fid);
|
|
|
|
fid = env->GetFieldID(cls, "minDurationOff", "F");
|
|
ans.min_duration_off = env->GetFloatField(config, fid);
|
|
|
|
return ans;
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jlong JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
|
|
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
|
return 0;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jlong JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
|
|
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
|
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
|
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
|
|
|
if (!config.Validate()) {
|
|
SHERPA_ONNX_LOGE("Errors found in config!");
|
|
return 0;
|
|
}
|
|
|
|
auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config);
|
|
|
|
return (jlong)sd;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT void JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
|
|
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
|
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
|
|
|
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
|
|
sd->SetConfig(config);
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT void JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
|
|
jobject /*obj*/,
|
|
jlong ptr) {
|
|
delete reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
|
|
}
|
|
|
|
static jobjectArray ProcessImpl(
|
|
JNIEnv *env,
|
|
const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
|
|
&segments) {
|
|
jclass cls =
|
|
env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");
|
|
|
|
jobjectArray obj_arr =
|
|
(jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr);
|
|
|
|
jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");
|
|
|
|
for (int32_t i = 0; i != segments.size(); ++i) {
|
|
const auto &s = segments[i];
|
|
jobject segment =
|
|
env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
|
|
env->SetObjectArrayElement(obj_arr, i, segment);
|
|
}
|
|
|
|
return obj_arr;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jobjectArray JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
|
|
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
|
jsize n = env->GetArrayLength(samples);
|
|
auto segments = sd->Process(p, n).SortByStartTime();
|
|
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
|
|
|
return ProcessImpl(env, segments);
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jobjectArray JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
|
jobject callback, jlong arg) {
|
|
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
|
|
[env, callback](int32_t num_processed_chunks, int32_t num_total_chunks,
|
|
void *data) -> int {
|
|
jclass cls = env->GetObjectClass(callback);
|
|
|
|
jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;");
|
|
if (mid == nullptr) {
|
|
SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
|
|
return 0;
|
|
}
|
|
|
|
jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks,
|
|
num_total_chunks, (jlong)data);
|
|
jclass jklass = env->GetObjectClass(ret);
|
|
jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
|
|
return env->CallIntMethod(ret, int_value_mid);
|
|
};
|
|
|
|
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
|
jsize n = env->GetArrayLength(samples);
|
|
auto segments =
|
|
sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime();
|
|
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
|
|
|
return ProcessImpl(env, segments);
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jint JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
|
|
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
|
|
return reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr)
|
|
->SampleRate();
|
|
}
|