Add Java API for speaker identification (#822)
This commit is contained in:
9
.github/workflows/run-java-test.yaml
vendored
9
.github/workflows/run-java-test.yaml
vendored
@@ -106,6 +106,15 @@ jobs:
|
|||||||
make -j4
|
make -j4
|
||||||
ls -lh lib
|
ls -lh lib
|
||||||
|
|
||||||
|
- name: Run java test (speaker identification)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd ./java-api-examples
|
||||||
|
./run-speaker-identification.sh
|
||||||
|
# Delete model files to save space
|
||||||
|
rm -rf *.onnx
|
||||||
|
rm -rf sr-data
|
||||||
|
|
||||||
- name: Run java test (audio tagging)
|
- name: Run java test (audio tagging)
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -50,3 +50,9 @@ The punctuation model supports both English and Chinese.
|
|||||||
./run-audio-tagging-zipformer-from-file.sh
|
./run-audio-tagging-zipformer-from-file.sh
|
||||||
./run-audio-tagging-ced-from-file.sh
|
./run-audio-tagging-ced-from-file.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Speaker identification
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./run-speaker-identification.sh
|
||||||
|
```
|
||||||
|
|||||||
132
java-api-examples/SpeakerIdentification.java
Normal file
132
java-api-examples/SpeakerIdentification.java
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
// Copyright 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
// This file shows how to use a speaker embedding extractor model for speaker
|
||||||
|
// identification.
|
||||||
|
import com.k2fsa.sherpa.onnx.*;
|
||||||
|
|
||||||
|
public class SpeakerIdentification {
|
||||||
|
public static float[] computeEmbedding(SpeakerEmbeddingExtractor extractor, String filename) {
|
||||||
|
WaveReader reader = new WaveReader(filename);
|
||||||
|
|
||||||
|
OnlineStream stream = extractor.createStream();
|
||||||
|
stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
|
||||||
|
stream.inputFinished();
|
||||||
|
|
||||||
|
float[] embedding = extractor.compute(stream);
|
||||||
|
stream.release();
|
||||||
|
|
||||||
|
return embedding;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
// Please download the model from
|
||||||
|
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||||
|
String model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";
|
||||||
|
SpeakerEmbeddingExtractorConfig config =
|
||||||
|
SpeakerEmbeddingExtractorConfig.builder()
|
||||||
|
.setModel(model)
|
||||||
|
.setNumThreads(1)
|
||||||
|
.setDebug(true)
|
||||||
|
.build();
|
||||||
|
SpeakerEmbeddingExtractor extractor = new SpeakerEmbeddingExtractor(config);
|
||||||
|
SpeakerEmbeddingManager manager = new SpeakerEmbeddingManager(extractor.getDim());
|
||||||
|
|
||||||
|
String[] spk1Files =
|
||||||
|
new String[] {
|
||||||
|
"./sr-data/enroll/fangjun-sr-1.wav",
|
||||||
|
"./sr-data/enroll/fangjun-sr-2.wav",
|
||||||
|
"./sr-data/enroll/fangjun-sr-3.wav",
|
||||||
|
};
|
||||||
|
|
||||||
|
float[][] spk1Vec = new float[spk1Files.length][];
|
||||||
|
|
||||||
|
for (int i = 0; i < spk1Files.length; ++i) {
|
||||||
|
spk1Vec[i] = computeEmbedding(extractor, spk1Files[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] spk2Files =
|
||||||
|
new String[] {
|
||||||
|
"./sr-data/enroll/leijun-sr-1.wav", "./sr-data/enroll/leijun-sr-2.wav",
|
||||||
|
};
|
||||||
|
|
||||||
|
float[][] spk2Vec = new float[spk2Files.length][];
|
||||||
|
|
||||||
|
for (int i = 0; i < spk2Files.length; ++i) {
|
||||||
|
spk2Vec[i] = computeEmbedding(extractor, spk2Files[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!manager.add("fangjun", spk1Vec)) {
|
||||||
|
System.out.println("Failed to register fangjun");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!manager.add("leijun", spk2Vec)) {
|
||||||
|
System.out.println("Failed to register leijun");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (manager.getNumSpeakers() != 2) {
|
||||||
|
System.out.println("There should be two speakers");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!manager.contains("fangjun")) {
|
||||||
|
System.out.println("It should contain the speaker fangjun");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!manager.contains("leijun")) {
|
||||||
|
System.out.println("It should contain the speaker leijun");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.println("---All speakers---");
|
||||||
|
String[] allSpeakers = manager.getAllSpeakerNames();
|
||||||
|
for (String s : allSpeakers) {
|
||||||
|
System.out.println(s);
|
||||||
|
}
|
||||||
|
System.out.println("------------");
|
||||||
|
|
||||||
|
String[] testFiles =
|
||||||
|
new String[] {
|
||||||
|
"./sr-data/test/fangjun-test-sr-1.wav",
|
||||||
|
"./sr-data/test/leijun-test-sr-1.wav",
|
||||||
|
"./sr-data/test/liudehua-test-sr-1.wav"
|
||||||
|
};
|
||||||
|
|
||||||
|
float threshold = 0.6f;
|
||||||
|
for (String file : testFiles) {
|
||||||
|
float[] embedding = computeEmbedding(extractor, file);
|
||||||
|
|
||||||
|
String name = manager.search(embedding, threshold);
|
||||||
|
if (name.isEmpty()) {
|
||||||
|
name = "<Unknown>";
|
||||||
|
}
|
||||||
|
System.out.printf("%s: %s\n", file, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// test verify
|
||||||
|
if (!manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) {
|
||||||
|
System.out.printf("testFiles[0] should match fangjun!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!manager.remove("fangjun")) {
|
||||||
|
System.out.println("Failed to remove fangjun");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) {
|
||||||
|
System.out.printf("%s should match no one!\n", testFiles[0]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (manager.getNumSpeakers() != 1) {
|
||||||
|
System.out.println("There should only 1 speaker left.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor.release();
|
||||||
|
manager.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
41
java-api-examples/run-speaker-identification.sh
Executable file
41
java-api-examples/run-speaker-identification.sh
Executable file
@@ -0,0 +1,41 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
|
||||||
|
mkdir -p ../build
|
||||||
|
pushd ../build
|
||||||
|
cmake \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
|
||||||
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_JNI=ON \
|
||||||
|
..
|
||||||
|
|
||||||
|
make -j4
|
||||||
|
ls -lh lib
|
||||||
|
popd
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
|
||||||
|
pushd ../sherpa-onnx/java-api
|
||||||
|
make
|
||||||
|
popd
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ./sr-data/enroll/leijun-sr-1.wav ]; then
|
||||||
|
curl -SL -o sr-data.tar.gz https://github.com/csukuangfj/sr-data/archive/refs/tags/v1.0.0.tar.gz
|
||||||
|
tar xvf sr-data.tar.gz
|
||||||
|
mv sr-data-1.0.0 sr-data
|
||||||
|
fi
|
||||||
|
|
||||||
|
java \
|
||||||
|
-Djava.library.path=$PWD/../build/lib \
|
||||||
|
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
|
||||||
|
./SpeakerIdentification.java
|
||||||
@@ -51,6 +51,10 @@ java_files += AudioTaggingConfig.java
|
|||||||
java_files += AudioEvent.java
|
java_files += AudioEvent.java
|
||||||
java_files += AudioTagging.java
|
java_files += AudioTagging.java
|
||||||
|
|
||||||
|
java_files += SpeakerEmbeddingExtractorConfig.java
|
||||||
|
java_files += SpeakerEmbeddingExtractor.java
|
||||||
|
java_files += SpeakerEmbeddingManager.java
|
||||||
|
|
||||||
class_files := $(java_files:%.java=%.class)
|
class_files := $(java_files:%.java=%.class)
|
||||||
|
|
||||||
java_files := $(addprefix src/$(package_dir)/,$(java_files))
|
java_files := $(addprefix src/$(package_dir)/,$(java_files))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ public class AudioTaggingConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AudioTaggingConfig.Builder();
|
return new Builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ public class OfflineRecognizer {
|
|||||||
System.loadLibrary("sherpa-onnx-jni");
|
System.loadLibrary("sherpa-onnx-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
private long ptr = 0; // this is the asr engine ptrss
|
private long ptr = 0;
|
||||||
|
|
||||||
public OfflineRecognizer(OfflineRecognizerConfig config) {
|
public OfflineRecognizer(OfflineRecognizerConfig config) {
|
||||||
ptr = newFromFile(config);
|
ptr = newFromFile(config);
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ public class OfflineTts {
|
|||||||
System.loadLibrary("sherpa-onnx-jni");
|
System.loadLibrary("sherpa-onnx-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
private long ptr = 0; // this is the asr engine ptrss
|
private long ptr = 0;
|
||||||
|
|
||||||
public OfflineTts(OfflineTtsConfig config) {
|
public OfflineTts(OfflineTtsConfig config) {
|
||||||
ptr = newFromFile(config);
|
ptr = newFromFile(config);
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ public class OnlineRecognizer {
|
|||||||
System.loadLibrary("sherpa-onnx-jni");
|
System.loadLibrary("sherpa-onnx-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
private long ptr = 0; // this is the asr engine ptrss
|
private long ptr = 0;
|
||||||
|
|
||||||
|
|
||||||
public OnlineRecognizer(OnlineRecognizerConfig config) {
|
public OnlineRecognizer(OnlineRecognizerConfig config) {
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
public class SpeakerEmbeddingExtractor {
|
||||||
|
static {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni");
|
||||||
|
}
|
||||||
|
|
||||||
|
private long ptr = 0;
|
||||||
|
|
||||||
|
public SpeakerEmbeddingExtractor(SpeakerEmbeddingExtractorConfig config) {
|
||||||
|
ptr = newFromFile(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void finalize() throws Throwable {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void release() {
|
||||||
|
if (this.ptr == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
delete(this.ptr);
|
||||||
|
this.ptr = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnlineStream createStream() {
|
||||||
|
long p = createStream(ptr);
|
||||||
|
return new OnlineStream(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isReady(OnlineStream s) {
|
||||||
|
return isReady(ptr, s.getPtr());
|
||||||
|
}
|
||||||
|
|
||||||
|
public float[] compute(OnlineStream s) {
|
||||||
|
return compute(ptr, s.getPtr());
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getDim() {
|
||||||
|
return dim(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
private native void delete(long ptr);
|
||||||
|
|
||||||
|
private native long newFromFile(SpeakerEmbeddingExtractorConfig config);
|
||||||
|
|
||||||
|
private native long createStream(long ptr);
|
||||||
|
|
||||||
|
private native boolean isReady(long ptr, long streamPtr);
|
||||||
|
|
||||||
|
private native float[] compute(long ptr, long streamPtr);
|
||||||
|
|
||||||
|
private native int dim(long ptr);
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
public class SpeakerEmbeddingExtractorConfig {
|
||||||
|
private final String model;
|
||||||
|
private final int numThreads;
|
||||||
|
private final boolean debug;
|
||||||
|
private final String provider;
|
||||||
|
|
||||||
|
private SpeakerEmbeddingExtractorConfig(Builder builder) {
|
||||||
|
this.model = builder.model;
|
||||||
|
this.numThreads = builder.numThreads;
|
||||||
|
this.debug = builder.debug;
|
||||||
|
this.provider = builder.provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private String model = "";
|
||||||
|
private int numThreads = 1;
|
||||||
|
private boolean debug = true;
|
||||||
|
private String provider = "cpu";
|
||||||
|
|
||||||
|
public SpeakerEmbeddingExtractorConfig build() {
|
||||||
|
return new SpeakerEmbeddingExtractorConfig(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public Builder setModel(String model) {
|
||||||
|
this.model = model;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setNumThreads(int numThreads) {
|
||||||
|
this.numThreads = numThreads;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setDebug(boolean debug) {
|
||||||
|
this.debug = debug;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setProvider(String provider) {
|
||||||
|
this.provider = provider;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
// Copyright 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
public class SpeakerEmbeddingManager {
|
||||||
|
static {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni");
|
||||||
|
}
|
||||||
|
|
||||||
|
private long ptr = 0;
|
||||||
|
|
||||||
|
public SpeakerEmbeddingManager(int dim) {
|
||||||
|
ptr = create(dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void finalize() throws Throwable {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void release() {
|
||||||
|
if (this.ptr == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
delete(this.ptr);
|
||||||
|
this.ptr = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean add(String name, float[] embedding) {
|
||||||
|
return add(ptr, name, embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean add(String name, float[][] embedding) {
|
||||||
|
return addList(ptr, name, embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean remove(String name) {
|
||||||
|
return remove(ptr, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String search(float[] embedding, float threshold) {
|
||||||
|
return search(ptr, embedding, threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean verify(String name, float[] embedding, float threshold) {
|
||||||
|
return verify(ptr, name, embedding, threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean contains(String name) {
|
||||||
|
return contains(ptr, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumSpeakers() {
|
||||||
|
return numSpeakers(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String[] getAllSpeakerNames() {
|
||||||
|
return allSpeakerNames(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
private native long create(int dim);
|
||||||
|
|
||||||
|
private native void delete(long ptr);
|
||||||
|
|
||||||
|
private native boolean add(long ptr, String name, float[] embedding);
|
||||||
|
|
||||||
|
private native boolean addList(long ptr, String name, float[][] embedding);
|
||||||
|
|
||||||
|
private native boolean remove(long ptr, String name);
|
||||||
|
|
||||||
|
private native String search(long ptr, float[] embedding, float threshold);
|
||||||
|
|
||||||
|
private native boolean verify(long ptr, String name, float[] embedding, float threshold);
|
||||||
|
|
||||||
|
private native boolean contains(long ptr, String name);
|
||||||
|
|
||||||
|
private native int numSpeakers(long ptr);
|
||||||
|
|
||||||
|
private native String[] allSpeakerNames(long ptr);
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@ public class SpokenLanguageIdentification {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private final Map<String, String> localeMap;
|
private final Map<String, String> localeMap;
|
||||||
private long ptr = 0; // this is the asr engine ptrss
|
private long ptr = 0;
|
||||||
|
|
||||||
public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) {
|
public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) {
|
||||||
ptr = newFromFile(config);
|
ptr = newFromFile(config);
|
||||||
|
|||||||
Reference in New Issue
Block a user