Add KWS examples for Java API (#930)

This commit is contained in:
Fangjun Kuang
2024-05-28 15:49:54 +08:00
committed by GitHub
parent bcaa6df389
commit 5860e45b4c
11 changed files with 295 additions and 2 deletions

View File

@@ -0,0 +1,66 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class KeywordSpotter {
static {
System.loadLibrary("sherpa-onnx-jni");
}
private long ptr = 0;
public KeywordSpotter(KeywordSpotterConfig config) {
ptr = newFromFile(config);
}
public OnlineStream createStream(String keywords) {
long p = createStream(ptr, keywords);
return new OnlineStream(p);
}
public OnlineStream createStream() {
long p = createStream(ptr, "");
return new OnlineStream(p);
}
public void decode(OnlineStream s) {
decode(ptr, s.getPtr());
}
public boolean isReady(OnlineStream s) {
return isReady(ptr, s.getPtr());
}
public KeywordSpotterResult getResult(OnlineStream s) {
Object[] arr = getResult(ptr, s.getPtr());
String keyword = (String) arr[0];
String[] tokens = (String[]) arr[1];
float[] timestamps = (float[]) arr[2];
return new KeywordSpotterResult(keyword, tokens, timestamps);
}
protected void finalize() throws Throwable {
release();
}
// You'd better call it manually if it is not used anymore
public void release() {
if (this.ptr == 0) {
return;
}
delete(this.ptr);
this.ptr = 0;
}
private native long newFromFile(KeywordSpotterConfig config);
private native void delete(long ptr);
private native long createStream(long ptr, String keywords);
private native void decode(long ptr, long streamPtr);
private native boolean isReady(long ptr, long streamPtr);
private native Object[] getResult(long ptr, long streamPtr);
}

View File

@@ -0,0 +1,77 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class KeywordSpotterConfig {
private final FeatureConfig featConfig;
private final OnlineModelConfig modelConfig;
private final int maxActivePaths;
private final String keywordsFile;
private final float keywordsScore;
private final float keywordsThreshold;
private final int numTrailingBlanks;
private KeywordSpotterConfig(Builder builder) {
this.featConfig = builder.featConfig;
this.modelConfig = builder.modelConfig;
this.maxActivePaths = builder.maxActivePaths;
this.keywordsFile = builder.keywordsFile;
this.keywordsScore = builder.keywordsScore;
this.keywordsThreshold = builder.keywordsThreshold;
this.numTrailingBlanks = builder.numTrailingBlanks;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private FeatureConfig featConfig = FeatureConfig.builder().build();
private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
private int maxActivePaths = 4;
private String keywordsFile = "keywords.txt";
private float keywordsScore = 1.5f;
private float keywordsThreshold = 0.25f;
private int numTrailingBlanks = 2;
public KeywordSpotterConfig build() {
return new KeywordSpotterConfig(this);
}
public Builder setFeatureConfig(FeatureConfig featConfig) {
this.featConfig = featConfig;
return this;
}
public Builder setOnlineModelConfig(OnlineModelConfig modelConfig) {
this.modelConfig = modelConfig;
return this;
}
public Builder setMaxActivePaths(int maxActivePaths) {
this.maxActivePaths = maxActivePaths;
return this;
}
public Builder setKeywordsFile(String keywordsFile) {
this.keywordsFile = keywordsFile;
return this;
}
public Builder setKeywordsScore(float keywordsScore) {
this.keywordsScore = keywordsScore;
return this;
}
public Builder setKeywordsThreshold(float keywordsThreshold) {
this.keywordsThreshold = keywordsThreshold;
return this;
}
public Builder setNumTrailingBlanks(int numTrailingBlanks) {
this.numTrailingBlanks = numTrailingBlanks;
return this;
}
}
}

View File

@@ -0,0 +1,27 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class KeywordSpotterResult {
private final String keyword;
private final String[] tokens;
private final float[] timestamps;
public KeywordSpotterResult(String keyword, String[] tokens, float[] timestamps) {
this.keyword = keyword;
this.tokens = tokens;
this.timestamps = timestamps;
}
public String getKeyword() {
return keyword;
}
public String[] getTokens() {
return tokens;
}
public float[] getTimestamps() {
return timestamps;
}
}

View File

@@ -10,7 +10,6 @@ public class OnlineRecognizer {
private long ptr = 0;
public OnlineRecognizer(OnlineRecognizerConfig config) {
ptr = newFromFile(config);
}
@@ -19,7 +18,6 @@ public class OnlineRecognizer {
decode(ptr, s.getPtr());
}
public boolean isReady(OnlineStream s) {
return isReady(ptr, s.getPtr());
}