Kotlin API for speaker diarization (#1415)
This commit is contained in:
1
kotlin-api-examples/OfflineSpeakerDiarization.kt
Symbolic link
1
kotlin-api-examples/OfflineSpeakerDiarization.kt
Symbolic link
@@ -0,0 +1 @@
|
||||
../sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
|
||||
@@ -285,6 +285,37 @@ function testPunctuation() {
|
||||
java -Djava.library.path=../build/lib -jar $out_filename
|
||||
}
|
||||
|
||||
# Downloads the models and the test wave used by the offline speaker
# diarization example (skipping anything already present), compiles the
# Kotlin example into a runnable jar and executes it.
function testOfflineSpeakerDiarization() {
  # Speaker segmentation model: shipped as a tarball that is unpacked and
  # then removed.
  if ! [ -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  fi

  # Speaker embedding model: a single .onnx file.
  if ! [ -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  fi

  # Test recording used by the Kotlin example.
  if ! [ -f ./0-four-speakers-zh.wav ]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  fi

  out_filename=test_offline_speaker_diarization.jar

  # Bundle the example together with the API sources it depends on.
  kotlinc-jvm -include-runtime -d $out_filename \
    test_offline_speaker_diarization.kt \
    OfflineSpeakerDiarization.kt \
    Speaker.kt \
    OnlineStream.kt \
    WaveReader.kt \
    faked-asset-manager.kt \
    faked-log.kt

  ls -lh $out_filename

  # The JNI shared library is looked up in ../build/lib.
  java -Djava.library.path=../build/lib -jar $out_filename
}
|
||||
|
||||
testOfflineSpeakerDiarization
|
||||
testSpeakerEmbeddingExtractor
|
||||
testOnlineAsr
|
||||
testTts
|
||||
|
||||
53
kotlin-api-examples/test_offline_speaker_diarization.kt
Normal file
53
kotlin-api-examples/test_offline_speaker_diarization.kt
Normal file
@@ -0,0 +1,53 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/** Entry point of the example jar: runs the offline speaker diarization demo. */
fun main() = testOfflineSpeakerDiarization()
|
||||
|
||||
/**
 * Progress callback handed to the diarizer.
 *
 * Prints the processed share of chunks as a percentage with two decimals.
 *
 * @param numProcessedChunks number of chunks processed so far
 * @param numTotalChunks total number of chunks
 * @param arg opaque value passed through by the caller; not used here
 * @return always 0
 */
fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int {
    val percent = numProcessedChunks.toFloat() / numTotalChunks * 100
    println("Progress: ${"%.2f".format(percent)}%")
    return 0
}
|
||||
|
||||
/**
 * Runs offline speaker diarization on ./0-four-speakers-zh.wav and prints one
 * line per detected segment in the form `start -- end speaker_N`.
 *
 * The native diarizer is released explicitly when the function exits,
 * including the early return taken when the sample rates do not match.
 */
fun testOfflineSpeakerDiarization() {
    val config = OfflineSpeakerDiarizationConfig(
        segmentation = OfflineSpeakerSegmentationModelConfig(
            pyannote = OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"),
        ),
        embedding = SpeakerEmbeddingExtractorConfig(
            model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
        ),

        // The test wave file ./0-four-speakers-zh.wav contains four speakers, so
        // we use numClusters=4 here. If you don't know the number of speakers
        // in the test wave file, please set the threshold like below.
        //
        // clustering=FastClusteringConfig(threshold=0.5),
        //
        // WARNING: You need to tune threshold by yourself.
        // A larger threshold leads to fewer clusters, i.e., fewer speakers.
        // A smaller threshold leads to more clusters, i.e., more speakers.
        //
        clustering = FastClusteringConfig(numClusters = 4),
    )

    val sd = OfflineSpeakerDiarization(config = config)
    try {
        val waveData = WaveReader.readWave(
            filename = "./0-four-speakers-zh.wav",
        )

        if (sd.sampleRate() != waveData.sampleRate) {
            println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}")
            return
        }

        // val segments = sd.process(waveData.samples) // this one is also ok
        val segments = sd.processWithCallback(waveData.samples, callback = ::callback)
        for (segment in segments) {
            println("${segment.start} -- ${segment.end} speaker_${segment.speaker}")
        }
    } finally {
        // Free the native object deterministically instead of waiting for
        // finalization.
        sd.release()
    }
}
|
||||
@@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult {
|
||||
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
|
||||
const;
|
||||
|
||||
public:
|
||||
private:
|
||||
std::vector<OfflineSpeakerDiarizationSegment> segments_;
|
||||
};
|
||||
|
||||
|
||||
@@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
|
||||
)
|
||||
endif()
|
||||
|
||||
# The speaker diarization JNI wrapper is compiled only when the feature is
# enabled.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  list(APPEND sources offline-speaker-diarization.cc)
endif()

# Build the JNI library from every source file collected in `sources`.
add_library(sherpa-onnx-jni ${sources})

target_compile_definitions(sherpa-onnx-jni
  PRIVATE
    SHERPA_ONNX_BUILD_SHARED_LIBS=1
)
|
||||
|
||||
219
sherpa-onnx/jni/offline-speaker-diarization.cc
Normal file
219
sherpa-onnx/jni/offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,219 @@
|
||||
// sherpa-onnx/jni/offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig(
|
||||
JNIEnv *env, jobject config) {
|
||||
OfflineSpeakerDiarizationConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
//---------- segmentation ----------
|
||||
fid = env->GetFieldID(
|
||||
cls, "segmentation",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;");
|
||||
jobject segmentation_config = env->GetObjectField(config, fid);
|
||||
jclass segmentation_config_cls = env->GetObjectClass(segmentation_config);
|
||||
|
||||
fid = env->GetFieldID(
|
||||
segmentation_config_cls, "pyannote",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;");
|
||||
jobject pyannote_config = env->GetObjectField(segmentation_config, fid);
|
||||
jclass pyannote_config_cls = env->GetObjectClass(pyannote_config);
|
||||
|
||||
fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(pyannote_config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.segmentation.pyannote.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I");
|
||||
ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid);
|
||||
|
||||
fid = env->GetFieldID(segmentation_config_cls, "debug", "Z");
|
||||
ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid);
|
||||
|
||||
fid = env->GetFieldID(segmentation_config_cls, "provider",
|
||||
"Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(segmentation_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.segmentation.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- embedding ----------
|
||||
fid = env->GetFieldID(
|
||||
cls, "embedding",
|
||||
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;");
|
||||
jobject embedding_config = env->GetObjectField(config, fid);
|
||||
jclass embedding_config_cls = env->GetObjectClass(embedding_config);
|
||||
|
||||
fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(embedding_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.embedding.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(embedding_config_cls, "numThreads", "I");
|
||||
ans.embedding.num_threads = env->GetIntField(embedding_config, fid);
|
||||
|
||||
fid = env->GetFieldID(embedding_config_cls, "debug", "Z");
|
||||
ans.embedding.debug = env->GetBooleanField(embedding_config, fid);
|
||||
|
||||
fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(embedding_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.embedding.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- clustering ----------
|
||||
fid = env->GetFieldID(cls, "clustering",
|
||||
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;");
|
||||
jobject clustering_config = env->GetObjectField(config, fid);
|
||||
jclass clustering_config_cls = env->GetObjectClass(clustering_config);
|
||||
|
||||
fid = env->GetFieldID(clustering_config_cls, "numClusters", "I");
|
||||
ans.clustering.num_clusters = env->GetIntField(clustering_config, fid);
|
||||
|
||||
fid = env->GetFieldID(clustering_config_cls, "threshold", "F");
|
||||
ans.clustering.threshold = env->GetFloatField(clustering_config, fid);
|
||||
|
||||
// its own fields
|
||||
fid = env->GetFieldID(cls, "minDurationOn", "F");
|
||||
ans.min_duration_on = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "minDurationOff", "F");
|
||||
ans.min_duration_off = env->GetFloatField(config, fid);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.newFromAsset().
//
// Creating the diarizer from an Android asset manager is not implemented
// here: the function always returns 0 (a null native handle). An error is
// logged so that the failure is visible instead of surfacing later as a null
// pointer dereference when the handle is used.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
    JNIEnv * /*env*/, jobject /*obj*/, jobject /*asset_manager*/,
    jobject /*_config*/) {
  SHERPA_ONNX_LOGE(
      "OfflineSpeakerDiarization.newFromAsset() is not implemented. "
      "Returning a null handle.");
  return 0;
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.newFromFile().
//
// Converts `_config` to the native config, logs it, and validates it.
// Returns the address of a heap-allocated OfflineSpeakerDiarization as a
// jlong, or 0 when validation fails. The returned object must be destroyed
// via the matching delete() entry point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto native_config =
      sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", native_config.ToString().c_str());

  if (native_config.Validate()) {
    auto *diarizer = new sherpa_onnx::OfflineSpeakerDiarization(native_config);
    return (jlong)diarizer;
  }

  SHERPA_ONNX_LOGE("Errors found in config!");
  return 0;
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.setConfig().
//
// `ptr` is the native handle previously returned by newFromFile(). The Java
// config is converted, logged, and forwarded to SetConfig() of the native
// object. `ptr` is not checked for null.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

  auto new_config =
      sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", new_config.ToString().c_str());

  diarizer->SetConfig(new_config);
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.delete().
//
// Destroys the native object whose address is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  delete diarizer;
}
|
||||
|
||||
// Converts native diarization segments into a Java array of
// com.k2fsa.sherpa.onnx.OfflineSpeakerDiarizationSegment objects.
//
// Each element is built with the (float start, float end, int speaker)
// constructor. Returns nullptr when a JNI lookup or allocation fails; in that
// case the JNI exception raised by the failing call is left pending for the
// Java caller.
static jobjectArray ProcessImpl(
    JNIEnv *env,
    const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
        &segments) {
  jclass cls =
      env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");
  if (cls == nullptr) {
    SHERPA_ONNX_LOGE("Failed to find class OfflineSpeakerDiarizationSegment");
    return nullptr;
  }

  jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");
  if (constructor == nullptr) {
    SHERPA_ONNX_LOGE(
        "Failed to find the constructor of OfflineSpeakerDiarizationSegment");
    return nullptr;
  }

  const jsize num_segments = static_cast<jsize>(segments.size());

  jobjectArray obj_arr = env->NewObjectArray(num_segments, cls, nullptr);
  if (obj_arr == nullptr) {
    return nullptr;
  }

  for (jsize i = 0; i != num_segments; ++i) {
    const auto &s = segments[i];
    jobject segment =
        env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
    if (segment == nullptr) {
      return nullptr;
    }

    env->SetObjectArrayElement(obj_arr, i, segment);

    // The array now holds the object. Drop the local reference right away so
    // that a long list of segments cannot exhaust the local reference table.
    env->DeleteLocalRef(segment);
  }

  return obj_arr;
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.process().
//
// Runs diarization on the float samples in `samples` using the native object
// behind `ptr` and returns the segments, sorted by start time, as a Java
// array. The sample buffer is released with JNI_ABORT, i.e. nothing is copied
// back to the Java array.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

  jfloat *data = env->GetFloatArrayElements(samples, nullptr);
  jsize num_samples = env->GetArrayLength(samples);

  auto result = diarizer->Process(data, num_samples);
  auto sorted_segments = result.SortByStartTime();

  env->ReleaseFloatArrayElements(samples, data, JNI_ABORT);

  return ProcessImpl(env, sorted_segments);
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.processWithCallback().
//
// Same as process(), but additionally invokes `callback` through its
// `invoke(IIJ)Ljava/lang/Integer;` method with
// (num_processed_chunks, num_total_chunks, arg).
//
// - The method ID is resolved once, not on every callback invocation.
// - Local references created per invocation are deleted immediately, so a
//   long recording cannot exhaust the JNI local reference table.
// - If the callback method cannot be found, the lookup error is cleared and
//   the callback is ignored.
// - If the callback throws, it is not called again; after processing
//   finishes nullptr is returned and the exception propagates to the caller.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jobject callback, jlong arg) {
  jclass callback_cls = env->GetObjectClass(callback);
  jmethodID invoke_mid =
      env->GetMethodID(callback_cls, "invoke", "(IIJ)Ljava/lang/Integer;");
  if (invoke_mid == nullptr) {
    // GetMethodID leaves an exception pending on failure. Clear it so that
    // the remaining JNI calls in this function stay valid.
    env->ExceptionClear();
    SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
  }

  std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
      [env, callback, invoke_mid](int32_t num_processed_chunks,
                                  int32_t num_total_chunks,
                                  void *data) -> int32_t {
    // Skip the call if there is no usable callback or an earlier invocation
    // already raised an exception.
    if (invoke_mid == nullptr || env->ExceptionCheck()) {
      return 0;
    }

    jobject ret = env->CallObjectMethod(
        callback, invoke_mid, static_cast<jint>(num_processed_chunks),
        static_cast<jint>(num_total_chunks), (jlong)data);
    if (env->ExceptionCheck() || ret == nullptr) {
      return 0;
    }

    // Unbox the returned java.lang.Integer.
    jclass ret_cls = env->GetObjectClass(ret);
    jmethodID int_value_mid = env->GetMethodID(ret_cls, "intValue", "()I");
    jint value = 0;
    if (int_value_mid != nullptr) {
      value = env->CallIntMethod(ret, int_value_mid);
    }

    env->DeleteLocalRef(ret_cls);
    env->DeleteLocalRef(ret);

    return value;
  };

  auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  jsize n = env->GetArrayLength(samples);
  auto segments =
      sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime();
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

  if (env->ExceptionCheck()) {
    // Let the exception raised inside the callback reach the Java caller.
    return nullptr;
  }

  return ProcessImpl(env, segments);
}
|
||||
|
||||
// JNI entry point for OfflineSpeakerDiarization.getSampleRate().
//
// Returns SampleRate() of the native object behind `ptr`. `ptr` is not
// checked for null.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  return diarizer->SampleRate();
}
|
||||
101
sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
Normal file
101
sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
Normal file
@@ -0,0 +1,101 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/**
 * Configuration of the pyannote speaker segmentation model.
 *
 * The native layer reads [model] by field name, so the property name and type
 * must not change.
 *
 * @property model path to the segmentation model file
 */
data class OfflineSpeakerSegmentationPyannoteModelConfig(
    var model: String,
)
|
||||
|
||||
/**
 * Configuration of the speaker segmentation stage.
 *
 * The native layer reads every property by field name, so names and types
 * must not change.
 *
 * @property pyannote configuration of the pyannote segmentation model
 * @property numThreads forwarded to the native `num_threads` setting
 * @property debug forwarded to the native `debug` setting
 * @property provider forwarded to the native `provider` setting
 */
data class OfflineSpeakerSegmentationModelConfig(
    var pyannote: OfflineSpeakerSegmentationPyannoteModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
/**
 * Configuration of the clustering stage.
 *
 * The native layer reads both properties by field name, so names and types
 * must not change.
 *
 * @property numClusters forwarded to the native `num_clusters` setting
 * @property threshold forwarded to the native `threshold` setting
 */
data class FastClusteringConfig(
    var numClusters: Int = -1,
    var threshold: Float = 0.5f,
)
|
||||
|
||||
/**
 * Top-level configuration of [OfflineSpeakerDiarization].
 *
 * The native layer reads every property by field name, so names and types
 * must not change.
 *
 * @property segmentation configuration of the speaker segmentation stage
 * @property embedding configuration of the speaker embedding extractor
 * @property clustering configuration of the clustering stage
 * @property minDurationOn forwarded to the native `min_duration_on` setting
 * @property minDurationOff forwarded to the native `min_duration_off` setting
 */
data class OfflineSpeakerDiarizationConfig(
    var segmentation: OfflineSpeakerSegmentationModelConfig,
    var embedding: SpeakerEmbeddingExtractorConfig,
    var clustering: FastClusteringConfig,
    var minDurationOn: Float = 0.2f,
    var minDurationOff: Float = 0.5f,
)
|
||||
|
||||
/**
 * One diarization result: a time span attributed to a single speaker.
 *
 * Instances are created by the native layer through the
 * `(Float, Float, Int)` constructor, so the parameter order and types must
 * not change.
 *
 * @property start start of the segment, in seconds
 * @property end end of the segment, in seconds
 * @property speaker ID of the speaker; counted from 0
 */
data class OfflineSpeakerDiarizationSegment(
    val start: Float, // in seconds
    val end: Float, // in seconds
    val speaker: Int, // ID of the speaker; count from 0
)
|
||||
|
||||
/**
 * Kotlin wrapper around the native offline speaker diarization object.
 *
 * The native object is created in the constructor: through `newFromAsset`
 * when [assetManager] is given, otherwise through `newFromFile`. Its address
 * is kept in [ptr]; 0 means "no native object" (creation failed or the
 * object was already released).
 *
 * Call [release] when the instance is no longer needed. Using the instance
 * without a native object throws [IllegalStateException] instead of passing
 * a null handle to native code.
 *
 * @param assetManager Android asset manager, or null to use the file-based
 *   native constructor
 * @param config diarization configuration
 */
class OfflineSpeakerDiarization(
    assetManager: AssetManager? = null,
    config: OfflineSpeakerDiarizationConfig,
) {
    private var ptr: Long = if (assetManager != null) {
        newFromAsset(assetManager, config)
    } else {
        newFromFile(config)
    }

    /** Frees the native object, if any. Safe to call more than once. */
    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native object immediately instead of waiting for finalization. */
    fun release() = finalize()

    // Only config.clustering is used. All other fields in config
    // are ignored
    fun setConfig(config: OfflineSpeakerDiarizationConfig) = setConfig(nativeHandle(), config)

    /** Returns the sample rate reported by the native object. */
    fun sampleRate(): Int = getSampleRate(nativeHandle())

    /** Runs diarization on [samples] and returns the detected segments. */
    fun process(samples: FloatArray): Array<OfflineSpeakerDiarizationSegment> =
        process(nativeHandle(), samples)

    /**
     * Runs diarization on [samples], reporting progress through [callback].
     *
     * @param callback receives the number of processed chunks, the total
     *   number of chunks and [arg]
     * @param arg opaque value handed to the native layer together with
     *   [callback]
     */
    fun processWithCallback(
        samples: FloatArray,
        callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
        arg: Long = 0,
    ): Array<OfflineSpeakerDiarizationSegment> =
        processWithCallback(nativeHandle(), samples, callback, arg)

    /** Returns [ptr] after verifying that it refers to a native object. */
    private fun nativeHandle(): Long {
        check(ptr != 0L) {
            "OfflineSpeakerDiarization has no native object: creation failed or release() was already called"
        }
        return ptr
    }

    private external fun delete(ptr: Long)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineSpeakerDiarizationConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineSpeakerDiarizationConfig,
    ): Long

    private external fun setConfig(ptr: Long, config: OfflineSpeakerDiarizationConfig)

    private external fun getSampleRate(ptr: Long): Int

    private external fun process(ptr: Long, samples: FloatArray): Array<OfflineSpeakerDiarizationSegment>

    private external fun processWithCallback(
        ptr: Long,
        samples: FloatArray,
        callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
        arg: Long,
    ): Array<OfflineSpeakerDiarizationSegment>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
Reference in New Issue
Block a user