diff --git a/.github/workflows/run-java-test.yaml b/.github/workflows/run-java-test.yaml index fbc1d671..487cd8c7 100644 --- a/.github/workflows/run-java-test.yaml +++ b/.github/workflows/run-java-test.yaml @@ -106,6 +106,15 @@ jobs: make -j4 ls -lh lib + - name: Run java test (speaker identification) + shell: bash + run: | + cd ./java-api-examples + ./run-speaker-identification.sh + # Delete model files to save space + rm -rf *.onnx + rm -rf sr-data + - name: Run java test (audio tagging) shell: bash run: | diff --git a/java-api-examples/README.md b/java-api-examples/README.md index 89ca5f08..e775994f 100755 --- a/java-api-examples/README.md +++ b/java-api-examples/README.md @@ -50,3 +50,9 @@ The punctuation model supports both English and Chinese. ./run-audio-tagging-zipformer-from-file.sh ./run-audio-tagging-ced-from-file.sh ``` + +## Speaker identification + +```bash +./run-speaker-identification.sh +``` diff --git a/java-api-examples/SpeakerIdentification.java b/java-api-examples/SpeakerIdentification.java new file mode 100644 index 00000000..971dc296 --- /dev/null +++ b/java-api-examples/SpeakerIdentification.java @@ -0,0 +1,132 @@ +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use a speaker embedding extractor model for speaker +// identification. +import com.k2fsa.sherpa.onnx.*; + +public class SpeakerIdentification { + public static float[] computeEmbedding(SpeakerEmbeddingExtractor extractor, String filename) { + WaveReader reader = new WaveReader(filename); + + OnlineStream stream = extractor.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + stream.inputFinished(); + + float[] embedding = extractor.compute(stream); + stream.release(); + + return embedding; + } + + public static void main(String[] args) { + // Please download the model from + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models + String model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; + SpeakerEmbeddingExtractorConfig config = + SpeakerEmbeddingExtractorConfig.builder() + .setModel(model) + .setNumThreads(1) + .setDebug(true) + .build(); + SpeakerEmbeddingExtractor extractor = new SpeakerEmbeddingExtractor(config); + SpeakerEmbeddingManager manager = new SpeakerEmbeddingManager(extractor.getDim()); + + String[] spk1Files = + new String[] { + "./sr-data/enroll/fangjun-sr-1.wav", + "./sr-data/enroll/fangjun-sr-2.wav", + "./sr-data/enroll/fangjun-sr-3.wav", + }; + + float[][] spk1Vec = new float[spk1Files.length][]; + + for (int i = 0; i < spk1Files.length; ++i) { + spk1Vec[i] = computeEmbedding(extractor, spk1Files[i]); + } + + String[] spk2Files = + new String[] { + "./sr-data/enroll/leijun-sr-1.wav", "./sr-data/enroll/leijun-sr-2.wav", + }; + + float[][] spk2Vec = new float[spk2Files.length][]; + + for (int i = 0; i < spk2Files.length; ++i) { + spk2Vec[i] = computeEmbedding(extractor, spk2Files[i]); + } + + if (!manager.add("fangjun", spk1Vec)) { + System.out.println("Failed to register fangjun"); + return; + } + + if (!manager.add("leijun", spk2Vec)) { + System.out.println("Failed to register leijun"); + return; + } + + if (manager.getNumSpeakers() != 2) { + System.out.println("There should be two speakers"); + return; + } + + if (!manager.contains("fangjun")) { + System.out.println("It should contain the speaker fangjun"); + return; + } + + if (!manager.contains("leijun")) { + System.out.println("It should contain the speaker leijun"); + return; + } + + System.out.println("---All speakers---"); + String[] allSpeakers = manager.getAllSpeakerNames(); + for (String s : allSpeakers) { + System.out.println(s); + } + System.out.println("------------"); + + String[] testFiles = + new String[] { + "./sr-data/test/fangjun-test-sr-1.wav", + "./sr-data/test/leijun-test-sr-1.wav", + "./sr-data/test/liudehua-test-sr-1.wav" + }; + + float threshold = 0.6f; + for (String file : testFiles) { + float[] embedding = computeEmbedding(extractor, file); + + String name = manager.search(embedding, threshold); + if (name.isEmpty()) { + name = ""; + } + System.out.printf("%s: %s\n", file, name); + } + + // test verify + if (!manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) { + System.out.printf("testFiles[0] should match fangjun!"); + return; + } + + if (!manager.remove("fangjun")) { + System.out.println("Failed to remove fangjun"); + return; + } + + if (manager.verify("fangjun", computeEmbedding(extractor, testFiles[0]), threshold)) { + System.out.printf("%s should match no one!\n", testFiles[0]); + return; + } + + if (manager.getNumSpeakers() != 1) { + System.out.println("There should only 1 speaker left."); + return; + } + + extractor.release(); + manager.release(); + } +} diff --git a/java-api-examples/run-speaker-identification.sh b/java-api-examples/run-speaker-identification.sh new file mode 100755 index 00000000..0bfd46e2 --- /dev/null +++ b/java-api-examples/run-speaker-identification.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +fi + +if [ ! -f ./sr-data/enroll/leijun-sr-1.wav ]; then + curl -SL -o sr-data.tar.gz https://github.com/csukuangfj/sr-data/archive/refs/tags/v1.0.0.tar.gz + tar xvf sr-data.tar.gz + mv sr-data-1.0.0 sr-data +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + ./SpeakerIdentification.java diff --git a/sherpa-onnx/java-api/Makefile b/sherpa-onnx/java-api/Makefile index 0d377adb..2b419a45 100644 --- a/sherpa-onnx/java-api/Makefile +++ b/sherpa-onnx/java-api/Makefile @@ -51,6 +51,10 @@ java_files += AudioTaggingConfig.java java_files += AudioEvent.java java_files += AudioTagging.java +java_files += SpeakerEmbeddingExtractorConfig.java +java_files += SpeakerEmbeddingExtractor.java +java_files += SpeakerEmbeddingManager.java + class_files := $(java_files:%.java=%.class) java_files := $(addprefix src/$(package_dir)/,$(java_files)) diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/AudioTaggingConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/AudioTaggingConfig.java index 5c6b5009..bcd286fe 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/AudioTaggingConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/AudioTaggingConfig.java @@ -14,7 +14,7 @@ public class AudioTaggingConfig { } public static Builder builder() { - return new AudioTaggingConfig.Builder(); + return new Builder(); } public static class Builder { diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineRecognizer.java index 8a511457..aa865fe3 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineRecognizer.java @@ -7,7 +7,7 @@ public class OfflineRecognizer { System.loadLibrary("sherpa-onnx-jni"); } - private long ptr = 0; // this is the asr engine ptrss + private long ptr = 0; public OfflineRecognizer(OfflineRecognizerConfig config) { ptr = newFromFile(config); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineTts.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineTts.java index caf93bd8..7762692b 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineTts.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineTts.java @@ -7,7 +7,7 @@ public class OfflineTts { System.loadLibrary("sherpa-onnx-jni"); } - private long ptr = 0; // this is the asr engine ptrss + private long ptr = 0; public OfflineTts(OfflineTtsConfig config) { ptr = newFromFile(config); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 181b8b3e..f2ce97a0 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -8,7 +8,7 @@ public class OnlineRecognizer { System.loadLibrary("sherpa-onnx-jni"); } - private long ptr = 0; // this is the asr engine ptrss + private long ptr = 0; public OnlineRecognizer(OnlineRecognizerConfig config) { diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractor.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractor.java new file mode 100644 index 00000000..5f872bdd --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractor.java @@ -0,0 +1,57 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class SpeakerEmbeddingExtractor { + static { + System.loadLibrary("sherpa-onnx-jni"); + } + + private long ptr = 0; + + public SpeakerEmbeddingExtractor(SpeakerEmbeddingExtractorConfig config) { + ptr = newFromFile(config); + } + + @Override + protected void finalize() throws Throwable { + release(); + } + + public void release() { + if (this.ptr == 0) { + return; + } + delete(this.ptr); + this.ptr = 0; + } + + public OnlineStream createStream() { + long p = createStream(ptr); + return new OnlineStream(p); + } + + public boolean isReady(OnlineStream s) { + return isReady(ptr, s.getPtr()); + } + + public float[] compute(OnlineStream s) { + return compute(ptr, s.getPtr()); + } + + public int getDim() { + return dim(ptr); + } + + private native void delete(long ptr); + + private native long newFromFile(SpeakerEmbeddingExtractorConfig config); + + private native long createStream(long ptr); + + private native boolean isReady(long ptr, long streamPtr); + + private native float[] compute(long ptr, long streamPtr); + + private native int dim(long ptr); +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig.java new file mode 100644 index 00000000..ffc688f3 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig.java @@ -0,0 +1,54 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class SpeakerEmbeddingExtractorConfig { + private final String model; + private final int numThreads; + private final boolean debug; + private final String provider; + + private SpeakerEmbeddingExtractorConfig(Builder builder) { + this.model = builder.model; + this.numThreads = builder.numThreads; + this.debug = builder.debug; + this.provider = builder.provider; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String model = ""; + private int numThreads = 1; + private boolean debug = true; + private String provider = "cpu"; + + public SpeakerEmbeddingExtractorConfig build() { + return new SpeakerEmbeddingExtractorConfig(this); + } + + + public Builder setModel(String model) { + this.model = model; + return this; + } + + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public Builder setDebug(boolean debug) { + this.debug = debug; + return this; + } + + public Builder setProvider(String provider) { + this.provider = provider; + return this; + } + } + +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingManager.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingManager.java new file mode 100644 index 00000000..f4af8d18 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpeakerEmbeddingManager.java @@ -0,0 +1,80 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class SpeakerEmbeddingManager { + static { + System.loadLibrary("sherpa-onnx-jni"); + } + + private long ptr = 0; + + public SpeakerEmbeddingManager(int dim) { + ptr = create(dim); + } + + @Override + protected void finalize() throws Throwable { + release(); + } + + public void release() { + if (this.ptr == 0) { + return; + } + delete(this.ptr); + this.ptr = 0; + } + + public boolean add(String name, float[] embedding) { + return add(ptr, name, embedding); + } + + public boolean add(String name, float[][] embedding) { + return addList(ptr, name, embedding); + } + + public boolean remove(String name) { + return remove(ptr, name); + } + + public String search(float[] embedding, float threshold) { + return search(ptr, embedding, threshold); + } + + public boolean verify(String name, float[] embedding, float threshold) { + return verify(ptr, name, embedding, threshold); + } + + public boolean contains(String name) { + return contains(ptr, name); + } + + public int getNumSpeakers() { + return numSpeakers(ptr); + } + + public String[] getAllSpeakerNames() { + return allSpeakerNames(ptr); + } + + private native long create(int dim); + + private native void delete(long ptr); + + private native boolean add(long ptr, String name, float[] embedding); + + private native boolean addList(long ptr, String name, float[][] embedding); + + private native boolean remove(long ptr, String name); + + private native String search(long ptr, float[] embedding, float threshold); + + private native boolean verify(long ptr, String name, float[] embedding, float threshold); + + private native boolean contains(long ptr, String name); + + private native int numSpeakers(long ptr); + + private native String[] allSpeakerNames(long ptr); +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpokenLanguageIdentification.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpokenLanguageIdentification.java index 379f3853..337e3961 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpokenLanguageIdentification.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/SpokenLanguageIdentification.java @@ -12,7 +12,7 @@ public class SpokenLanguageIdentification { } private final Map localeMap; - private long ptr = 0; // this is the asr engine ptrss + private long ptr = 0; public SpokenLanguageIdentification(SpokenLanguageIdentificationConfig config) { ptr = newFromFile(config);