Java API for speaker diarization (#1416)

This commit is contained in:
Fangjun Kuang
2024-10-11 16:51:40 +08:00
committed by GitHub
parent 2d412b1190
commit 1851ff6337
14 changed files with 471 additions and 1 deletions

View File

@@ -0,0 +1,44 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class FastClusteringConfig {
private final int numClusters;
private final float threshold;
private FastClusteringConfig(Builder builder) {
this.numClusters = builder.numClusters;
this.threshold = builder.threshold;
}
public static Builder builder() {
return new Builder();
}
public int getNumClusters() {
return numClusters;
}
public float getThreshold() {
return threshold;
}
public static class Builder {
private int numClusters = -1;
private float threshold = 0.5f;
public FastClusteringConfig build() {
return new FastClusteringConfig(this);
}
public Builder setNumClusters(int numClusters) {
this.numClusters = numClusters;
return this;
}
public Builder setThreshold(float threshold) {
this.threshold = threshold;
return this;
}
}
}

View File

@@ -0,0 +1,61 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineSpeakerDiarization {
static {
System.loadLibrary("sherpa-onnx-jni");
}
private long ptr = 0;
public OfflineSpeakerDiarization(OfflineSpeakerDiarizationConfig config) {
ptr = newFromFile(config);
}
public int getSampleRate() {
return getSampleRate(ptr);
}
// Only config.clustering is used. All other fields are ignored
public void setConfig(OfflineSpeakerDiarizationConfig config) {
setConfig(ptr, config);
}
public OfflineSpeakerDiarizationSegment[] process(float[] samples) {
return process(ptr, samples);
}
public OfflineSpeakerDiarizationSegment[] processWithCallback(float[] samples, OfflineSpeakerDiarizationCallback callback) {
return processWithCallback(ptr, samples, callback, 0);
}
public OfflineSpeakerDiarizationSegment[] processWithCallback(float[] samples, OfflineSpeakerDiarizationCallback callback, long arg) {
return processWithCallback(ptr, samples, callback, arg);
}
protected void finalize() throws Throwable {
release();
}
// You'd better call it manually if it is not used anymore
public void release() {
if (this.ptr == 0) {
return;
}
delete(this.ptr);
this.ptr = 0;
}
private native int getSampleRate(long ptr);
private native void delete(long ptr);
private native long newFromFile(OfflineSpeakerDiarizationConfig config);
private native void setConfig(long ptr, OfflineSpeakerDiarizationConfig config);
private native OfflineSpeakerDiarizationSegment[] process(long ptr, float[] samples);
private native OfflineSpeakerDiarizationSegment[] processWithCallback(long ptr, float[] samples, OfflineSpeakerDiarizationCallback callback, long arg);
}

View File

@@ -0,0 +1,8 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
@FunctionalInterface
public interface OfflineSpeakerDiarizationCallback {
Integer invoke(int numProcessedChunks, int numTotalCunks, long arg);
}

View File

@@ -0,0 +1,79 @@
package com.k2fsa.sherpa.onnx;
public class OfflineSpeakerDiarizationConfig {
private final OfflineSpeakerSegmentationModelConfig segmentation;
private final SpeakerEmbeddingExtractorConfig embedding;
private final FastClusteringConfig clustering;
private final float minDurationOn;
private final float minDurationOff;
private OfflineSpeakerDiarizationConfig(Builder builder) {
this.segmentation = builder.segmentation;
this.embedding = builder.embedding;
this.clustering = builder.clustering;
this.minDurationOff = builder.minDurationOff;
this.minDurationOn = builder.minDurationOn;
}
public static Builder builder() {
return new Builder();
}
public OfflineSpeakerSegmentationModelConfig getSegmentation() {
return segmentation;
}
public SpeakerEmbeddingExtractorConfig getEmbedding() {
return embedding;
}
public FastClusteringConfig getClustering() {
return clustering;
}
public float getMinDurationOff() {
return minDurationOff;
}
public float getMinDurationOn() {
return minDurationOn;
}
public static class Builder {
private OfflineSpeakerSegmentationModelConfig segmentation = OfflineSpeakerSegmentationModelConfig.builder().build();
private SpeakerEmbeddingExtractorConfig embedding = SpeakerEmbeddingExtractorConfig.builder().build();
private FastClusteringConfig clustering = FastClusteringConfig.builder().build();
private float minDurationOn = 0.2f;
private float minDurationOff = 0.5f;
public OfflineSpeakerDiarizationConfig build() {
return new OfflineSpeakerDiarizationConfig(this);
}
public Builder setSegmentation(OfflineSpeakerSegmentationModelConfig segmentation) {
this.segmentation = segmentation;
return this;
}
public Builder setEmbedding(SpeakerEmbeddingExtractorConfig embedding) {
this.embedding = embedding;
return this;
}
public Builder setClustering(FastClusteringConfig clustering) {
this.clustering = clustering;
return this;
}
public Builder setMinDurationOff(float minDurationOff) {
this.minDurationOff = minDurationOff;
return this;
}
public Builder setMinDurationOn(float minDurationOn) {
this.minDurationOn = minDurationOn;
return this;
}
}
}

View File

@@ -0,0 +1,27 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineSpeakerDiarizationSegment {
private final float start;
private final float end;
private final int speaker;
public OfflineSpeakerDiarizationSegment(float start, float end, int speaker) {
this.start = start;
this.end = end;
this.speaker = speaker;
}
public float getStart() {
return start;
}
public float getEnd() {
return end;
}
public int getSpeaker() {
return speaker;
}
}

View File

@@ -0,0 +1,52 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineSpeakerSegmentationModelConfig {
private final OfflineSpeakerSegmentationPyannoteModelConfig pyannote;
private final int numThreads;
private final boolean debug;
private final String provider;
private OfflineSpeakerSegmentationModelConfig(Builder builder) {
this.pyannote = builder.pyannote;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
this.provider = builder.provider;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private OfflineSpeakerSegmentationPyannoteModelConfig pyannote = OfflineSpeakerSegmentationPyannoteModelConfig.builder().build();
private int numThreads = 1;
private boolean debug = true;
private String provider = "cpu";
public OfflineSpeakerSegmentationModelConfig build() {
return new OfflineSpeakerSegmentationModelConfig(this);
}
public Builder setPyannote(OfflineSpeakerSegmentationPyannoteModelConfig pyannote) {
this.pyannote = pyannote;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
public Builder setDebug(boolean debug) {
this.debug = debug;
return this;
}
public Builder setProvider(String provider) {
this.provider = provider;
return this;
}
}
}

View File

@@ -0,0 +1,32 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineSpeakerSegmentationPyannoteModelConfig {
private final String model;
private OfflineSpeakerSegmentationPyannoteModelConfig(Builder builder) {
this.model = builder.model;
}
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public static class Builder {
private String model = "";
public OfflineSpeakerSegmentationPyannoteModelConfig build() {
return new OfflineSpeakerSegmentationPyannoteModelConfig(this);
}
public Builder setModel(String model) {
this.model = model;
return this;
}
}
}

View File

@@ -1,3 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
@FunctionalInterface

View File

@@ -50,5 +50,4 @@ public class SpeakerEmbeddingExtractorConfig {
return this;
}
}
}