From 5860e45b4c1a38fc0f28b273cf3060169b326eaf Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 May 2024 15:49:54 +0800 Subject: [PATCH] Add KWS examples for Java API (#930) --- .github/workflows/run-java-test.yaml | 7 ++ java-api-examples/KeywordSpotterFromFile.java | 65 ++++++++++++++++ java-api-examples/README.md | 6 ++ .../VadNonStreamingParaformer.java | 4 + java-api-examples/VadRemoveSilence.java | 2 + java-api-examples/run-kws-from-file.sh | 37 +++++++++ sherpa-onnx/java-api/Makefile | 4 + .../com/k2fsa/sherpa/onnx/KeywordSpotter.java | 66 ++++++++++++++++ .../sherpa/onnx/KeywordSpotterConfig.java | 77 +++++++++++++++++++ .../sherpa/onnx/KeywordSpotterResult.java | 27 +++++++ .../k2fsa/sherpa/onnx/OnlineRecognizer.java | 2 - 11 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 java-api-examples/KeywordSpotterFromFile.java create mode 100755 java-api-examples/run-kws-from-file.sh create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterConfig.java create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterResult.java diff --git a/.github/workflows/run-java-test.yaml b/.github/workflows/run-java-test.yaml index 93117d03..c000b277 100644 --- a/.github/workflows/run-java-test.yaml +++ b/.github/workflows/run-java-test.yaml @@ -107,6 +107,13 @@ jobs: make -j4 ls -lh lib + - name: Run java test (kws) + shell: bash + run: | + cd ./java-api-examples + ./run-kws-from-file.sh + rm -rf sherpa-onnx-* + - name: Run java test (VAD + Non-streaming Paraformer) shell: bash run: | diff --git a/java-api-examples/KeywordSpotterFromFile.java b/java-api-examples/KeywordSpotterFromFile.java new file mode 100644 index 00000000..1b7a739a --- /dev/null +++ b/java-api-examples/KeywordSpotterFromFile.java @@ -0,0 +1,65 @@ +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use a keyword spotter model to spot keywords 
from +// a file. + +import com.k2fsa.sherpa.onnx.*; + +public class KeywordSpotterFromFile { + public static void main(String[] args) { + // please download test files from https://github.com/k2-fsa/sherpa-onnx/releases/tag/kws-models + String encoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + String decoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + String joiner = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"; + String tokens = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"; + + String keywordsFile = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"; + + String waveFilename = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"; + + OnlineTransducerModelConfig transducer = + OnlineTransducerModelConfig.builder() + .setEncoder(encoder) + .setDecoder(decoder) + .setJoiner(joiner) + .build(); + + OnlineModelConfig modelConfig = + OnlineModelConfig.builder() + .setTransducer(transducer) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + .build(); + + KeywordSpotterConfig config = + KeywordSpotterConfig.builder() + .setOnlineModelConfig(modelConfig) + .setKeywordsFile(keywordsFile) + .build(); + + KeywordSpotter kws = new KeywordSpotter(config); + OnlineStream stream = kws.createStream(); + + WaveReader reader = new WaveReader(waveFilename); + + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + float[] tailPaddings = new float[(int) (0.8 * reader.getSampleRate())]; + stream.acceptWaveform(tailPaddings, reader.getSampleRate()); + while (kws.isReady(stream)) { + kws.decode(stream); + + String keyword = kws.getResult(stream).getKeyword(); + if (!keyword.isEmpty()) { + System.out.printf("Detected keyword: %s\n", keyword); + } + } + + kws.release(); + } +} diff --git 
a/java-api-examples/README.md b/java-api-examples/README.md index 96973e15..18f53fae 100755 --- a/java-api-examples/README.md +++ b/java-api-examples/README.md @@ -68,3 +68,9 @@ The punctuation model supports both English and Chinese. ```bash ./run-vad-non-streaming-paraformer.sh ``` + +## Keyword spotter + +```bash +./run-kws-from-file.sh +``` diff --git a/java-api-examples/VadNonStreamingParaformer.java b/java-api-examples/VadNonStreamingParaformer.java index be54d2d2..48e446ae 100644 --- a/java-api-examples/VadNonStreamingParaformer.java +++ b/java-api-examples/VadNonStreamingParaformer.java @@ -91,6 +91,7 @@ public class VadNonStreamingParaformer { stream.acceptWaveform(segment.getSamples(), 16000); recognizer.decode(stream); String text = recognizer.getResult(stream).getText(); + stream.release(); if (!text.isEmpty()) { System.out.printf("%.3f--%.3f: %s\n", startTime, startTime + duration, text); @@ -100,5 +101,8 @@ public class VadNonStreamingParaformer { } } } + + vad.release(); + recognizer.release(); } } diff --git a/java-api-examples/VadRemoveSilence.java b/java-api-examples/VadRemoveSilence.java index 2d5e48d9..4ee40d0d 100644 --- a/java-api-examples/VadRemoveSilence.java +++ b/java-api-examples/VadRemoveSilence.java @@ -75,5 +75,7 @@ public class VadRemoveSilence { String outFilename = "lei-jun-test-no-silence.wav"; WaveWriter.write(outFilename, allSamples, 16000); System.out.printf("Saved to %s\n", outFilename); + + vad.release(); } } diff --git a/java-api-examples/run-kws-from-file.sh b/java-api-examples/run-kws-from-file.sh new file mode 100755 index 00000000..0a60dcb0 --- /dev/null +++ b/java-api-examples/run-kws-from-file.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! 
-f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! -f ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + ./KeywordSpotterFromFile.java diff --git a/sherpa-onnx/java-api/Makefile b/sherpa-onnx/java-api/Makefile index 18bbbcba..4fabf6a4 100644 --- a/sherpa-onnx/java-api/Makefile +++ b/sherpa-onnx/java-api/Makefile @@ -62,6 +62,10 @@ java_files += VadModelConfig.java java_files += SpeechSegment.java java_files += Vad.java +java_files += KeywordSpotterConfig.java +java_files += KeywordSpotterResult.java +java_files += KeywordSpotter.java + class_files := $(java_files:%.java=%.class) java_files := $(addprefix src/$(package_dir)/,$(java_files)) diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java new file mode 100644 index 00000000..a1b897b0 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java @@ -0,0 +1,66 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class KeywordSpotter { + static { + System.loadLibrary("sherpa-onnx-jni"); + } + + private long ptr = 0; + + public 
KeywordSpotter(KeywordSpotterConfig config) { + ptr = newFromFile(config); + } + + public OnlineStream createStream(String keywords) { + long p = createStream(ptr, keywords); + return new OnlineStream(p); + } + + public OnlineStream createStream() { + long p = createStream(ptr, ""); + return new OnlineStream(p); + } + + public void decode(OnlineStream s) { + decode(ptr, s.getPtr()); + } + + public boolean isReady(OnlineStream s) { + return isReady(ptr, s.getPtr()); + } + + public KeywordSpotterResult getResult(OnlineStream s) { + Object[] arr = getResult(ptr, s.getPtr()); + String keyword = (String) arr[0]; + String[] tokens = (String[]) arr[1]; + float[] timestamps = (float[]) arr[2]; + return new KeywordSpotterResult(keyword, tokens, timestamps); + } + + protected void finalize() throws Throwable { + release(); + } + + // Call release() explicitly once this spotter is no longer needed instead of relying on finalize() + public void release() { + if (this.ptr == 0) { + return; + } + delete(this.ptr); + this.ptr = 0; + } + + private native long newFromFile(KeywordSpotterConfig config); + + private native void delete(long ptr); + + private native long createStream(long ptr, String keywords); + + private native void decode(long ptr, long streamPtr); + + private native boolean isReady(long ptr, long streamPtr); + + private native Object[] getResult(long ptr, long streamPtr); +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterConfig.java new file mode 100644 index 00000000..9b617713 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterConfig.java @@ -0,0 +1,77 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class KeywordSpotterConfig { + private final FeatureConfig featConfig; + private final OnlineModelConfig modelConfig; + + private final int maxActivePaths; + private final String keywordsFile; + private final float keywordsScore; + private 
final float keywordsThreshold; + private final int numTrailingBlanks; + + private KeywordSpotterConfig(Builder builder) { + this.featConfig = builder.featConfig; + this.modelConfig = builder.modelConfig; + this.maxActivePaths = builder.maxActivePaths; + this.keywordsFile = builder.keywordsFile; + this.keywordsScore = builder.keywordsScore; + this.keywordsThreshold = builder.keywordsThreshold; + this.numTrailingBlanks = builder.numTrailingBlanks; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private FeatureConfig featConfig = FeatureConfig.builder().build(); + private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build(); + private int maxActivePaths = 4; + private String keywordsFile = "keywords.txt"; + private float keywordsScore = 1.5f; + private float keywordsThreshold = 0.25f; + private int numTrailingBlanks = 2; + + public KeywordSpotterConfig build() { + return new KeywordSpotterConfig(this); + } + + public Builder setFeatureConfig(FeatureConfig featConfig) { + this.featConfig = featConfig; + return this; + } + + public Builder setOnlineModelConfig(OnlineModelConfig modelConfig) { + this.modelConfig = modelConfig; + return this; + } + + public Builder setMaxActivePaths(int maxActivePaths) { + this.maxActivePaths = maxActivePaths; + return this; + } + + public Builder setKeywordsFile(String keywordsFile) { + this.keywordsFile = keywordsFile; + return this; + } + + public Builder setKeywordsScore(float keywordsScore) { + this.keywordsScore = keywordsScore; + return this; + } + + public Builder setKeywordsThreshold(float keywordsThreshold) { + this.keywordsThreshold = keywordsThreshold; + return this; + } + + public Builder setNumTrailingBlanks(int numTrailingBlanks) { + this.numTrailingBlanks = numTrailingBlanks; + return this; + } + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterResult.java 
b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterResult.java new file mode 100644 index 00000000..0106759d --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotterResult.java @@ -0,0 +1,27 @@ +// Copyright 2024 Xiaomi Corporation + +package com.k2fsa.sherpa.onnx; + +public class KeywordSpotterResult { + private final String keyword; + private final String[] tokens; + private final float[] timestamps; + + public KeywordSpotterResult(String keyword, String[] tokens, float[] timestamps) { + this.keyword = keyword; + this.tokens = tokens; + this.timestamps = timestamps; + } + + public String getKeyword() { + return keyword; + } + + public String[] getTokens() { + return tokens; + } + + public float[] getTimestamps() { + return timestamps; + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index f2ce97a0..c98d15bb 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -10,7 +10,6 @@ public class OnlineRecognizer { private long ptr = 0; - public OnlineRecognizer(OnlineRecognizerConfig config) { ptr = newFromFile(config); } @@ -19,7 +18,6 @@ public class OnlineRecognizer { decode(ptr, s.getPtr()); } - public boolean isReady(OnlineStream s) { return isReady(ptr, s.getPtr()); }