Refactor Java API (#806)
This commit is contained in:
4
sherpa-onnx/java-api/.gitignore
vendored
4
sherpa-onnx/java-api/.gitignore
vendored
@@ -1,2 +1,6 @@
|
||||
.idea
|
||||
java-api.iml
|
||||
out
|
||||
META-INF
|
||||
build
|
||||
*.jar
|
||||
|
||||
42
sherpa-onnx/java-api/Makefile
Normal file
42
sherpa-onnx/java-api/Makefile
Normal file
@@ -0,0 +1,42 @@
|
||||
|
||||
# Builds the sherpa-onnx Java API into a single jar.
# All .class and .jar files are put inside out_dir.
out_dir := build
out_jar := $(out_dir)/sherpa-onnx.jar

# Directory form of the Java package com.k2fsa.sherpa.onnx.
package_dir := com/k2fsa/sherpa/onnx

# NOTE(review): every source file is compiled on its own against the classes
# already present in $(out_dir) (see the javac rule at the bottom), so this
# list appears to be ordered with dependencies before their dependents.
# Confirm before reordering or building with -j.
java_files := WaveReader.java
java_files += EndpointRule.java
java_files += EndpointConfig.java
java_files += FeatureConfig.java
java_files += OnlineLMConfig.java
java_files += OnlineParaformerModelConfig.java
java_files += OnlineZipformer2CtcModelConfig.java
java_files += OnlineTransducerModelConfig.java
java_files += OnlineModelConfig.java
java_files += OnlineStream.java
java_files += OnlineRecognizerConfig.java
java_files += OnlineRecognizerResult.java
java_files += OnlineRecognizer.java

# Foo.java -> Foo.class (bare file names; directories are added below).
class_files := $(java_files:%.java=%.class)

java_files := $(addprefix src/$(package_dir)/,$(java_files))
class_files := $(addprefix $(out_dir)/$(package_dir)/,$(class_files))

$(info -- java files $(java_files))
$(info --)
$(info -- class files $(class_files))

# .PHONY must be upper case: a lower-case ".phony" is just an ordinary target
# and would not stop a file named "all" or "clean" from shadowing these rules.
.PHONY: all clean

all: $(out_jar)

# Package everything below $(out_dir) into the jar.
$(out_jar): $(class_files)
	jar --create --verbose --file $(out_jar) -C $(out_dir) .

clean:
	$(RM) -rfv $(out_dir)

# Static pattern rule: compile one .java file into its .class file, resolving
# references to the other API classes through the already-built $(out_dir).
$(class_files): $(out_dir)/$(package_dir)/%.class: src/$(package_dir)/%.java
	javac -d $(out_dir) --class-path $(out_dir) $<
|
||||
@@ -1,18 +1,22 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class EndpointConfig {
|
||||
|
||||
private final EndpointRule rule1;
|
||||
private final EndpointRule rule2;
|
||||
private final EndpointRule rule3;
|
||||
|
||||
public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) {
|
||||
this.rule1 = rule1;
|
||||
this.rule2 = rule2;
|
||||
this.rule3 = rule3;
|
||||
private EndpointConfig(Builder builder) {
|
||||
this.rule1 = builder.rule1;
|
||||
this.rule2 = builder.rule2;
|
||||
this.rule3 = builder.rule3;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public EndpointRule getRule1() {
|
||||
@@ -26,4 +30,42 @@ public class EndpointConfig {
|
||||
public EndpointRule getRule3() {
|
||||
return rule3;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private EndpointRule rule1 = EndpointRule.builder().
|
||||
setMustContainNonSilence(false).
|
||||
setMinTrailingSilence(2.4f).
|
||||
setMinUtteranceLength(0).
|
||||
build();
|
||||
private EndpointRule rule2 = EndpointRule.builder().
|
||||
setMustContainNonSilence(true).
|
||||
setMinTrailingSilence(1.4f).
|
||||
setMinUtteranceLength(0).
|
||||
build();
|
||||
private EndpointRule rule3 = EndpointRule.builder().
|
||||
setMustContainNonSilence(false).
|
||||
setMinTrailingSilence(0.0f).
|
||||
setMinUtteranceLength(20.0f).
|
||||
build();
|
||||
|
||||
public EndpointConfig build() {
|
||||
return new EndpointConfig(this);
|
||||
}
|
||||
|
||||
public Builder setRule1(EndpointRule rule) {
|
||||
this.rule1 = rule;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setRule2(EndpointRule rule) {
|
||||
this.rule2 = rule;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setRul3(EndpointRule rule) {
|
||||
this.rule3 = rule;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class EndpointRule {
|
||||
|
||||
private final boolean mustContainNonSilence;
|
||||
private final float minTrailingSilence;
|
||||
private final float minUtteranceLength;
|
||||
|
||||
public EndpointRule(
|
||||
boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) {
|
||||
this.mustContainNonSilence = mustContainNonSilence;
|
||||
this.minTrailingSilence = minTrailingSilence;
|
||||
this.minUtteranceLength = minUtteranceLength;
|
||||
private EndpointRule(Builder builder) {
|
||||
this.mustContainNonSilence = builder.mustContainNonSilence;
|
||||
this.minTrailingSilence = builder.minTrailingSilence;
|
||||
this.minUtteranceLength = builder.minUtteranceLength;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public float getMinTrailingSilence() {
|
||||
@@ -27,4 +29,29 @@ public class EndpointRule {
|
||||
public boolean getMustContainNonSilence() {
|
||||
return mustContainNonSilence;
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private boolean mustContainNonSilence = false;
|
||||
private float minTrailingSilence = 0;
|
||||
private float minUtteranceLength = 0;
|
||||
|
||||
public EndpointRule build() {
|
||||
return new EndpointRule(this);
|
||||
}
|
||||
|
||||
public Builder setMustContainNonSilence(boolean mustContainNonSilence) {
|
||||
this.mustContainNonSilence = mustContainNonSilence;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMinTrailingSilence(float minTrailingSilence) {
|
||||
this.minTrailingSilence = minTrailingSilence;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMinUtteranceLength(float minUtteranceLength) {
|
||||
this.minUtteranceLength = minUtteranceLength;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
@@ -8,9 +7,13 @@ public class FeatureConfig {
|
||||
private final int sampleRate;
|
||||
private final int featureDim;
|
||||
|
||||
public FeatureConfig(int sampleRate, int featureDim) {
|
||||
this.sampleRate = sampleRate;
|
||||
this.featureDim = featureDim;
|
||||
private FeatureConfig(Builder builder) {
|
||||
this.sampleRate = builder.sampleRate;
|
||||
this.featureDim = builder.featureDim;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public int getSampleRate() {
|
||||
@@ -20,4 +23,23 @@ public class FeatureConfig {
|
||||
public int getFeatureDim() {
|
||||
return featureDim;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private int sampleRate = 16000;
|
||||
private int featureDim = 80;
|
||||
|
||||
public FeatureConfig build() {
|
||||
return new FeatureConfig(this);
|
||||
}
|
||||
|
||||
public Builder setSampleRate(int sampleRate) {
|
||||
this.sampleRate = sampleRate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFeatureDim(int featureDim) {
|
||||
this.featureDim = featureDim;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineLMConfig {
|
||||
|
||||
private final String model;
|
||||
private final float scale;
|
||||
|
||||
public OnlineLMConfig(String model, float scale) {
|
||||
this.model = model;
|
||||
this.scale = scale;
|
||||
private OnlineLMConfig(Builder builder) {
|
||||
this.model = builder.model;
|
||||
this.scale = builder.scale;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
@@ -20,4 +24,23 @@ public class OnlineLMConfig {
|
||||
public float getScale() {
|
||||
return scale;
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String model = "";
|
||||
private float scale = 1.0f;
|
||||
|
||||
public OnlineLMConfig build() {
|
||||
return new OnlineLMConfig(this);
|
||||
}
|
||||
|
||||
public Builder setModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setScale(float scale) {
|
||||
this.scale = scale;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,36 +1,30 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineModelConfig {
|
||||
private final OnlineParaformerModelConfig paraformer;
|
||||
private final OnlineTransducerModelConfig transducer;
|
||||
private final OnlineParaformerModelConfig paraformer;
|
||||
private final OnlineZipformer2CtcModelConfig zipformer2Ctc;
|
||||
private final String tokens;
|
||||
private final int numThreads;
|
||||
private final boolean debug;
|
||||
private final String provider = "cpu";
|
||||
private String modelType = "";
|
||||
private final String provider;
|
||||
private final String modelType;
|
||||
private OnlineModelConfig(Builder builder) {
|
||||
this.transducer = builder.transducer;
|
||||
this.paraformer = builder.paraformer;
|
||||
this.zipformer2Ctc = builder.zipformer2Ctc;
|
||||
this.tokens = builder.tokens;
|
||||
this.numThreads = builder.numThreads;
|
||||
this.debug = builder.debug;
|
||||
this.provider = builder.provider;
|
||||
this.modelType = builder.modelType;
|
||||
}
|
||||
|
||||
public OnlineModelConfig(
|
||||
String tokens,
|
||||
int numThreads,
|
||||
boolean debug,
|
||||
String modelType,
|
||||
OnlineParaformerModelConfig paraformer,
|
||||
OnlineTransducerModelConfig transducer,
|
||||
OnlineZipformer2CtcModelConfig zipformer2Ctc
|
||||
) {
|
||||
|
||||
this.tokens = tokens;
|
||||
this.numThreads = numThreads;
|
||||
this.debug = debug;
|
||||
this.modelType = modelType;
|
||||
this.paraformer = paraformer;
|
||||
this.transducer = transducer;
|
||||
this.zipformer2Ctc = zipformer2Ctc;
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public OnlineParaformerModelConfig getParaformer() {
|
||||
@@ -41,6 +35,10 @@ public class OnlineModelConfig {
|
||||
return transducer;
|
||||
}
|
||||
|
||||
public OnlineZipformer2CtcModelConfig getZipformer2Ctc() {
|
||||
return zipformer2Ctc;
|
||||
}
|
||||
|
||||
public String getTokens() {
|
||||
return tokens;
|
||||
}
|
||||
@@ -52,4 +50,67 @@ public class OnlineModelConfig {
|
||||
public boolean getDebug() {
|
||||
return debug;
|
||||
}
|
||||
|
||||
public String getProvider() {
|
||||
return provider;
|
||||
}
|
||||
|
||||
public String getModelType() {
|
||||
return modelType;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private OnlineParaformerModelConfig paraformer = OnlineParaformerModelConfig.builder().build();
|
||||
private OnlineTransducerModelConfig transducer = OnlineTransducerModelConfig.builder().build();
|
||||
private OnlineZipformer2CtcModelConfig zipformer2Ctc = OnlineZipformer2CtcModelConfig.builder().build();
|
||||
private String tokens = "";
|
||||
private int numThreads = 1;
|
||||
private boolean debug = true;
|
||||
private String provider = "cpu";
|
||||
private String modelType = "";
|
||||
|
||||
public OnlineModelConfig build() {
|
||||
return new OnlineModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setTransducer(OnlineTransducerModelConfig transducer) {
|
||||
this.transducer = transducer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setParaformer(OnlineParaformerModelConfig paraformer) {
|
||||
this.paraformer = paraformer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setZipformer2Ctc(OnlineZipformer2CtcModelConfig zipformer2Ctc) {
|
||||
this.zipformer2Ctc = zipformer2Ctc;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTokens(String tokens) {
|
||||
this.tokens = tokens;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumThreads(int numThreads) {
|
||||
this.numThreads = numThreads;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDebug(boolean debug) {
|
||||
this.debug = debug;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setProvider(String provider) {
|
||||
this.provider = provider;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setModelType(String modelType) {
|
||||
this.modelType = modelType;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
@@ -8,9 +7,13 @@ public class OnlineParaformerModelConfig {
|
||||
private final String encoder;
|
||||
private final String decoder;
|
||||
|
||||
public OnlineParaformerModelConfig(String encoder, String decoder) {
|
||||
this.encoder = encoder;
|
||||
this.decoder = decoder;
|
||||
private OnlineParaformerModelConfig(Builder builder) {
|
||||
this.encoder = builder.encoder;
|
||||
this.decoder = builder.decoder;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getEncoder() {
|
||||
@@ -20,4 +23,23 @@ public class OnlineParaformerModelConfig {
|
||||
public String getDecoder() {
|
||||
return decoder;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String encoder = "";
|
||||
private String decoder = "";
|
||||
|
||||
public OnlineParaformerModelConfig build() {
|
||||
return new OnlineParaformerModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setEncoder(String encoder) {
|
||||
this.encoder = encoder;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDecoder(String decoder) {
|
||||
this.decoder = decoder;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,234 +1,21 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
* // the online recognizer for sherpa-onnx, it can load config from a file
|
||||
* // or by argument
|
||||
*/
|
||||
/*
|
||||
usage example:
|
||||
|
||||
String cfgpath=appdir+"/modelconfig.cfg";
|
||||
OnlineRecognizer.setSoPath(soPath); //set so lib path
|
||||
|
||||
OnlineRecognizer rcgOjb = new OnlineRecognizer(); //create a recognizer
|
||||
rcgOjb = new OnlineRecognizer(cfgFile); //set model config file
|
||||
CreateStream streamObj=rcgOjb.CreateStream(); //create a stream for read wav data
|
||||
float[] buffer = rcgOjb.readWavFile(wavfilename); // read data from file
|
||||
streamObj.acceptWaveform(buffer); // feed stream with data
|
||||
streamObj.inputFinished(); // tell engine you done with all data
|
||||
OnlineStream ssObj[] = new OnlineStream[1];
|
||||
while (rcgOjb.isReady(streamObj)) { // engine is ready for unprocessed data
|
||||
ssObj[0] = streamObj;
|
||||
rcgOjb.decodeStreams(ssObj); // decode for multiple stream
|
||||
// rcgOjb.DecodeStream(streamObj); // decode for single stream
|
||||
}
|
||||
|
||||
String recText = "simple:" + rcgOjb.getResult(streamObj) + "\n";
|
||||
byte[] utf8Data = recText.getBytes(StandardCharsets.UTF_8);
|
||||
System.out.println(new String(utf8Data));
|
||||
rcgOjb.reSet(streamObj);
|
||||
rcgOjb.releaseStream(streamObj); // release stream
|
||||
rcgOjb.release(); // release recognizer
|
||||
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.util.Enumeration;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
public class OnlineRecognizer {
|
||||
static {
|
||||
System.loadLibrary("sherpa-onnx-jni");
|
||||
}
|
||||
|
||||
private long ptr = 0; // this is the asr engine ptrss
|
||||
|
||||
private int sampleRate = 16000;
|
||||
|
||||
// load config file for OnlineRecognizer
|
||||
public OnlineRecognizer(String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg, zipformer2CtcConfig);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
|
||||
} catch (Exception e) {
|
||||
System.err.println(e);
|
||||
}
|
||||
}
|
||||
|
||||
// use for android asset_manager ANDROID_API__ >= 9
|
||||
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
|
||||
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg, zipformer2CtcConfig);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
||||
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
|
||||
|
||||
} catch (Exception e) {
|
||||
System.err.println(e);
|
||||
}
|
||||
}
|
||||
|
||||
// set onlineRecognizer by parameter
|
||||
public OnlineRecognizer(
|
||||
String tokens,
|
||||
String encoder,
|
||||
String decoder,
|
||||
String joiner,
|
||||
int numThreads,
|
||||
int sampleRate,
|
||||
int featureDim,
|
||||
boolean enableEndpointDetection,
|
||||
float rule1MinTrailingSilence,
|
||||
float rule2MinTrailingSilence,
|
||||
float rule3MinUtteranceLength,
|
||||
String decodingMethod,
|
||||
String lm_model,
|
||||
float lm_scale,
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore,
|
||||
String modelType) {
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
||||
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(encoder, decoder, joiner);
|
||||
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig);
|
||||
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
||||
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
enableEndpointDetection,
|
||||
decodingMethod,
|
||||
maxActivePaths,
|
||||
hotwordsFile,
|
||||
hotwordsScore);
|
||||
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
/**
 * Creates a recognizer from the given configuration.
 *
 * <p>The handle returned by the native {@code newFromFile} call is kept in
 * {@code ptr} and passed to every later native call.
 *
 * @param config the recognizer configuration handed to the native layer
 */
public OnlineRecognizer(OnlineRecognizerConfig config) {
    this.ptr = newFromFile(config);
}
|
||||
|
||||
/*
|
||||
public static float[] readWavFile(String fileName) {
|
||||
// read data from the filename
|
||||
Object[] wavdata = readWave(fileName);
|
||||
@@ -238,139 +25,67 @@ public class OnlineRecognizer {
|
||||
|
||||
return floatData;
|
||||
}
|
||||
*/
|
||||
|
||||
// load the libsherpa-onnx-jni.so lib
|
||||
public static void loadSoLib(String soPath) {
|
||||
// load libsherpa-onnx-jni.so lib from the path
|
||||
|
||||
System.out.println("so lib path=" + soPath + "\n");
|
||||
System.load(soPath.trim());
|
||||
System.out.println("load so lib succeed\n");
|
||||
public void decode(OnlineStream s) {
|
||||
decode(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
public static void setSoPath(String soPath) {
|
||||
OnlineRecognizer.loadSoLib(soPath);
|
||||
OnlineStream.loadSoLib(soPath);
|
||||
|
||||
public boolean isReady(OnlineStream s) {
|
||||
return isReady(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
private static native Object[] readWave(String fileName); // static
|
||||
|
||||
private Map<String, String> readProperties(String modelCfgPath) {
|
||||
// read and parse config file
|
||||
Properties props = new Properties();
|
||||
Map<String, String> proMap = new HashMap<>();
|
||||
try {
|
||||
File file = new File(modelCfgPath);
|
||||
if (!file.exists()) {
|
||||
System.out.println("model cfg file not exists!");
|
||||
System.exit(0);
|
||||
}
|
||||
InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath));
|
||||
props.load(in);
|
||||
Enumeration en = props.propertyNames();
|
||||
while (en.hasMoreElements()) {
|
||||
String key = (String) en.nextElement();
|
||||
String Property = props.getProperty(key);
|
||||
proMap.put(key, Property);
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return proMap;
|
||||
public boolean isEndpoint(OnlineStream s) {
|
||||
return isEndpoint(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
public void decodeStream(OnlineStream s) throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = s.getPtr();
|
||||
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
|
||||
// when feeded samples to engine, call DecodeStream to let it process
|
||||
decodeStream(this.ptr, streamPtr);
|
||||
public void reset(OnlineStream s) {
|
||||
reset(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
public void decodeStreams(OnlineStream[] ssOjb) throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
// decode for multiple streams
|
||||
long[] ss = new long[ssOjb.length];
|
||||
for (int i = 0; i < ssOjb.length; i++) {
|
||||
ss[i] = ssOjb[i].getPtr();
|
||||
if (ss[i] == 0) throw new Exception("null exception for stream ptr");
|
||||
}
|
||||
decodeStreams(this.ptr, ss);
|
||||
}
|
||||
|
||||
public boolean isReady(OnlineStream s) throws Exception {
|
||||
// whether the engine is ready for decode
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = s.getPtr();
|
||||
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
|
||||
return isReady(this.ptr, streamPtr);
|
||||
}
|
||||
|
||||
public String getResult(OnlineStream s) throws Exception {
|
||||
// get text from the engine
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = s.getPtr();
|
||||
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
|
||||
return getResult(this.ptr, streamPtr);
|
||||
}
|
||||
|
||||
public boolean isEndpoint(OnlineStream s) throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = s.getPtr();
|
||||
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
|
||||
return isEndpoint(this.ptr, streamPtr);
|
||||
}
|
||||
|
||||
public void reSet(OnlineStream s) throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = s.getPtr();
|
||||
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
|
||||
reSet(this.ptr, streamPtr);
|
||||
}
|
||||
|
||||
public OnlineStream createStream() throws Exception {
|
||||
// create one stream for data to feed in
|
||||
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
|
||||
long streamPtr = createStream(this.ptr);
|
||||
OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate);
|
||||
return stream;
|
||||
public OnlineStream createStream() {
|
||||
long p = createStream(ptr, "");
|
||||
return new OnlineStream(p);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
release();
|
||||
}
|
||||
|
||||
// recognizer release, you'd better call it manually if not use anymore
|
||||
public void release() {
|
||||
if (this.ptr == 0) return;
|
||||
deleteOnlineRecognizer(this.ptr);
|
||||
if (this.ptr == 0) {
|
||||
return;
|
||||
}
|
||||
delete(this.ptr);
|
||||
this.ptr = 0;
|
||||
}
|
||||
|
||||
// JNI interface libsherpa-onnx-jni.so
|
||||
|
||||
// stream release, you'd better call it manually if not use anymore
|
||||
public void releaseStream(OnlineStream s) {
|
||||
s.release();
|
||||
public OnlineRecognizerResult getResult(OnlineStream s) {
|
||||
Object[] arr = getResult(ptr, s.getPtr());
|
||||
String text = (String) arr[0];
|
||||
String[] tokens = (String[]) arr[1];
|
||||
float[] timestamps = (float[]) arr[2];
|
||||
return new OnlineRecognizerResult(text, tokens, timestamps);
|
||||
}
|
||||
|
||||
private native String getResult(long ptr, long streamPtr);
|
||||
|
||||
private native void decodeStream(long ptr, long streamPtr);
|
||||
private native void delete(long ptr);
|
||||
|
||||
private native void decodeStreams(long ptr, long[] ssPtr);
|
||||
private native long newFromFile(OnlineRecognizerConfig config);
|
||||
|
||||
private native boolean isReady(long ptr, long streamPtr);
|
||||
private native long createStream(long ptr, String hotwords);
|
||||
|
||||
// first parameter keep for android asset_manager ANDROID_API__ >= 9
|
||||
private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
|
||||
private native void reset(long ptr, long streamPtr);
|
||||
|
||||
private native long createStream(long ptr);
|
||||
|
||||
private native void deleteOnlineRecognizer(long ptr);
|
||||
private native void decode(long ptr, long streamPtr);
|
||||
|
||||
private native boolean isEndpoint(long ptr, long streamPtr);
|
||||
|
||||
private native void reSet(long ptr, long streamPtr);
|
||||
}
|
||||
private native boolean isReady(long ptr, long streamPtr);
|
||||
|
||||
private native Object[] getResult(long ptr, long streamPtr);
|
||||
}
|
||||
@@ -1,66 +1,95 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineRecognizerConfig {
|
||||
private final FeatureConfig featConfig;
|
||||
private final OnlineModelConfig modelConfig;
|
||||
private final EndpointConfig endpointConfig;
|
||||
private final OnlineLMConfig lmConfig;
|
||||
private final EndpointConfig endpointConfig;
|
||||
private final boolean enableEndpoint;
|
||||
private final String decodingMethod;
|
||||
private final int maxActivePaths;
|
||||
private final String hotwordsFile;
|
||||
private final float hotwordsScore;
|
||||
|
||||
public OnlineRecognizerConfig(
|
||||
FeatureConfig featConfig,
|
||||
OnlineModelConfig modelConfig,
|
||||
EndpointConfig endpointConfig,
|
||||
OnlineLMConfig lmConfig,
|
||||
boolean enableEndpoint,
|
||||
String decodingMethod,
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore) {
|
||||
this.featConfig = featConfig;
|
||||
this.modelConfig = modelConfig;
|
||||
this.endpointConfig = endpointConfig;
|
||||
this.lmConfig = lmConfig;
|
||||
this.enableEndpoint = enableEndpoint;
|
||||
this.decodingMethod = decodingMethod;
|
||||
this.maxActivePaths = maxActivePaths;
|
||||
this.hotwordsFile = hotwordsFile;
|
||||
this.hotwordsScore = hotwordsScore;
|
||||
private OnlineRecognizerConfig(Builder builder) {
|
||||
this.featConfig = builder.featConfig;
|
||||
this.modelConfig = builder.modelConfig;
|
||||
this.lmConfig = builder.lmConfig;
|
||||
this.endpointConfig = builder.endpointConfig;
|
||||
this.enableEndpoint = builder.enableEndpoint;
|
||||
this.decodingMethod = builder.decodingMethod;
|
||||
this.maxActivePaths = builder.maxActivePaths;
|
||||
this.hotwordsFile = builder.hotwordsFile;
|
||||
this.hotwordsScore = builder.hotwordsScore;
|
||||
}
|
||||
|
||||
public OnlineLMConfig getLmConfig() {
|
||||
return lmConfig;
|
||||
}
|
||||
|
||||
public FeatureConfig getFeatConfig() {
|
||||
return featConfig;
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public OnlineModelConfig getModelConfig() {
|
||||
return modelConfig;
|
||||
}
|
||||
|
||||
public EndpointConfig getEndpointConfig() {
|
||||
return endpointConfig;
|
||||
}
|
||||
public static class Builder {
|
||||
private FeatureConfig featConfig = FeatureConfig.builder().build();
|
||||
private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
|
||||
private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build();
|
||||
private EndpointConfig endpointConfig = EndpointConfig.builder().build();
|
||||
private boolean enableEndpoint = true;
|
||||
private String decodingMethod = "greedy_search";
|
||||
private int maxActivePaths = 4;
|
||||
private String hotwordsFile = "";
|
||||
private float hotwordsScore = 1.5f;
|
||||
|
||||
public boolean isEnableEndpoint() {
|
||||
return enableEndpoint;
|
||||
}
|
||||
public OnlineRecognizerConfig build() {
|
||||
return new OnlineRecognizerConfig(this);
|
||||
}
|
||||
|
||||
public String getDecodingMethod() {
|
||||
return decodingMethod;
|
||||
}
|
||||
public Builder setFeatureConfig(FeatureConfig featConfig) {
|
||||
this.featConfig = featConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public int getMaxActivePaths() {
|
||||
return maxActivePaths;
|
||||
public Builder setOnlineModelConfig(OnlineModelConfig modelConfig) {
|
||||
this.modelConfig = modelConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setOnlineLMConfig(OnlineLMConfig lmConfig) {
|
||||
this.lmConfig = lmConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEndpointConfig(EndpointConfig endpointConfig) {
|
||||
this.endpointConfig = endpointConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEnableEndpoint(boolean enableEndpoint) {
|
||||
this.enableEndpoint = enableEndpoint;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDecodingMethod(String decodingMethod) {
|
||||
this.decodingMethod = decodingMethod;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMaxActivePaths(int maxActivePaths) {
|
||||
this.maxActivePaths = maxActivePaths;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHotwordsFile(String hotwordsFile) {
|
||||
this.hotwordsFile = hotwordsFile;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHotwordsScore(float hotwordsScore) {
|
||||
this.hotwordsScore = hotwordsScore;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
/**
 * Immutable holder for a recognition result: the recognized text, its tokens
 * and the per-result timestamps.
 *
 * <p>The arrays passed to the constructor are copied, and every getter returns
 * a fresh copy, so neither the original caller nor a consumer of the getters
 * can change the state held by this object. {@code null} arguments are stored
 * and returned as {@code null}.
 *
 * <p>NOTE(review): presumably {@code timestamps[i]} belongs to
 * {@code tokens[i]}; the unit is not visible from this class. Confirm against
 * the code that constructs the result.
 */
public class OnlineRecognizerResult {
    private final String text;
    private final String[] tokens;
    private final float[] timestamps;

    /**
     * Creates a result from the given values.
     *
     * @param text the recognized text; may be {@code null}
     * @param tokens the tokens; copied, may be {@code null}
     * @param timestamps the timestamps; copied, may be {@code null}
     */
    public OnlineRecognizerResult(String text, String[] tokens, float[] timestamps) {
        this.text = text;
        // Copy so that later writes to the caller's arrays do not leak in here.
        this.tokens = tokens == null ? null : tokens.clone();
        this.timestamps = timestamps == null ? null : timestamps.clone();
    }

    /**
     * Returns the recognized text.
     *
     * @return the text given at construction
     */
    public String getText() {
        return text;
    }

    /**
     * Returns the tokens.
     *
     * @return a copy of the token array, or {@code null} if none was given
     */
    public String[] getTokens() {
        return tokens == null ? null : tokens.clone();
    }

    /**
     * Returns the timestamps.
     *
     * @return a copy of the timestamp array, or {@code null} if none was given
     */
    public float[] getTimestamps() {
        return timestamps == null ? null : timestamps.clone();
    }
}
|
||||
@@ -1,84 +1,56 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Stream is used for feeding data to the asr engine
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineStream {
|
||||
private long ptr = 0; // this is the stream ptr
|
||||
|
||||
private int sampleRate = 16000;
|
||||
|
||||
// assign ptr to this stream in construction
|
||||
public OnlineStream(long ptr, int sampleRate) {
|
||||
this.ptr = ptr;
|
||||
this.sampleRate = sampleRate;
|
||||
static {
|
||||
System.loadLibrary("sherpa-onnx-jni");
|
||||
}
|
||||
|
||||
public static void loadSoLib(String soPath) {
|
||||
// load .so lib from the path
|
||||
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
|
||||
private long ptr = 0;
|
||||
|
||||
public OnlineStream() {
|
||||
this.ptr = 0;
|
||||
}
|
||||
|
||||
public OnlineStream(long ptr) {
|
||||
this.ptr = ptr;
|
||||
}
|
||||
|
||||
public long getPtr() {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
public void acceptWaveform(float[] samples) throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
|
||||
public void setPtr(long ptr) {
|
||||
this.ptr = ptr;
|
||||
}
|
||||
|
||||
// feed wave data to asr engine
|
||||
acceptWaveform(this.ptr, this.sampleRate, samples);
|
||||
public void acceptWaveform(float[] samples, int sampleRate) {
|
||||
acceptWaveform(this.ptr, samples, sampleRate);
|
||||
}
|
||||
|
||||
public void inputFinished() {
|
||||
// add some tail padding
|
||||
int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
|
||||
float[] tailPaddings = new float[padLen]; // default value is 0
|
||||
acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
|
||||
|
||||
// tell the engine that all data has been fed
|
||||
inputFinished(this.ptr);
|
||||
}
|
||||
|
||||
public void release() {
|
||||
// the stream object must be released after use
|
||||
if (this.ptr == 0) return;
|
||||
deleteStream(this.ptr);
|
||||
if (this.ptr == 0) {
|
||||
return;
|
||||
}
|
||||
delete(this.ptr);
|
||||
this.ptr = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
release();
|
||||
super.finalize();
|
||||
}
|
||||
|
||||
public boolean isLastFrame() throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
|
||||
return isLastFrame(this.ptr);
|
||||
}
|
||||
|
||||
public void reSet() throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
|
||||
reSet(this.ptr);
|
||||
}
|
||||
|
||||
public int featureDim() throws Exception {
|
||||
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
|
||||
return featureDim(this.ptr);
|
||||
}
|
||||
|
||||
// JNI interface libsherpa-onnx-jni.so
|
||||
private native void acceptWaveform(long ptr, int sampleRate, float[] samples);
|
||||
private native void acceptWaveform(long ptr, float[] samples, int sampleRate);
|
||||
|
||||
private native void inputFinished(long ptr);
|
||||
|
||||
private native void deleteStream(long ptr);
|
||||
|
||||
private native int numFramesReady(long ptr);
|
||||
|
||||
private native boolean isLastFrame(long ptr);
|
||||
|
||||
private native void reSet(long ptr);
|
||||
|
||||
private native int featureDim(long ptr);
|
||||
}
|
||||
private native void delete(long ptr);
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
/*
|
||||
* // Copyright 2022-2023 by zhaoming
|
||||
*/
|
||||
// Copyright 2022-2023 by zhaoming
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
@@ -9,10 +8,14 @@ public class OnlineTransducerModelConfig {
|
||||
private final String decoder;
|
||||
private final String joiner;
|
||||
|
||||
public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
|
||||
this.encoder = encoder;
|
||||
this.decoder = decoder;
|
||||
this.joiner = joiner;
|
||||
private OnlineTransducerModelConfig(Builder builder) {
|
||||
this.encoder = builder.encoder;
|
||||
this.decoder = builder.decoder;
|
||||
this.joiner = builder.joiner;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getEncoder() {
|
||||
@@ -26,4 +29,29 @@ public class OnlineTransducerModelConfig {
|
||||
public String getJoiner() {
|
||||
return joiner;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String encoder = "";
|
||||
private String decoder = "";
|
||||
private String joiner = "";
|
||||
|
||||
public OnlineTransducerModelConfig build() {
|
||||
return new OnlineTransducerModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setEncoder(String encoder) {
|
||||
this.encoder = encoder;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDecoder(String decoder) {
|
||||
this.decoder = decoder;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setJoiner(String joiner) {
|
||||
this.joiner = joiner;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,31 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlineZipformer2CtcModelConfig {
|
||||
private final String model;
|
||||
|
||||
public OnlineZipformer2CtcModelConfig(String model) {
|
||||
this.model = model;
|
||||
private OnlineZipformer2CtcModelConfig(Builder builder) {
|
||||
this.model = builder.model;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private String model = "";
|
||||
|
||||
public OnlineZipformer2CtcModelConfig build() {
|
||||
return new OnlineZipformer2CtcModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
/**
 * Reads a wave file through the {@code sherpa-onnx-jni} native library and
 * exposes its samples and sample rate.
 *
 * <p>Only single-channel, 16-bit wave files are supported. The program exits
 * if the given file has a wrong format.
 */
public class WaveReader {
    // The native library must be loaded before readWaveFromFile can be called.
    static {
        System.loadLibrary("sherpa-onnx-jni");
    }

    private final int sampleRate;
    private final float[] samples;

    /**
     * Reads the whole file eagerly.
     *
     * @param filename the wave file to read
     */
    public WaveReader(String filename) {
        // The native call returns a two-element array:
        // index 0 holds the samples, index 1 holds the boxed sample rate.
        final Object[] waveData = readWaveFromFile(filename);
        this.samples = (float[]) waveData[0];
        this.sampleRate = (Integer) waveData[1];
    }

    /**
     * Returns the sample rate reported for the file.
     *
     * @return the sample rate
     */
    public int getSampleRate() {
        return this.sampleRate;
    }

    /**
     * Returns the samples read from the file.
     *
     * @return the sample array held by this reader (not a copy)
     */
    public float[] getSamples() {
        return this.samples;
    }

    // Implemented in the sherpa-onnx-jni native library.
    private native Object[] readWaveFromFile(String filename);
}
|
||||
Reference in New Issue
Block a user