Fix java tests.
This commit is contained in:
Fangjun Kuang
2024-02-26 13:49:37 +08:00
committed by GitHub
parent ee37d9bd92
commit fb04366179
15 changed files with 561 additions and 532 deletions

2
java-api-examples/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
lib
hs_err*

View File

@@ -9,10 +9,11 @@ LIB_FILES = \
$(LIB_SRC_DIR)/OnlineLMConfig.java \ $(LIB_SRC_DIR)/OnlineLMConfig.java \
$(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
$(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \ $(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \
$(LIB_SRC_DIR)/OnlineZipformer2CtcModelConfig.java \
$(LIB_SRC_DIR)/OnlineModelConfig.java \ $(LIB_SRC_DIR)/OnlineModelConfig.java \
$(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
$(LIB_SRC_DIR)/OnlineStream.java \ $(LIB_SRC_DIR)/OnlineStream.java \
$(LIB_SRC_DIR)/OnlineRecognizer.java \ $(LIB_SRC_DIR)/OnlineRecognizer.java
WEBSOCKET_DIR:= ./src/websocketsrv WEBSOCKET_DIR:= ./src/websocketsrv
WEBSOCKET_FILES = \ WEBSOCKET_FILES = \
@@ -42,10 +43,10 @@ vpath %.java src
buildfile: buildfile:
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE) $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE)
buildmic: buildmic:
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic) $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic)
rebuild: clean all rebuild: clean all
@@ -63,8 +64,8 @@ clean:
mkdir -p $(BUILD_DIR) mkdir -p $(BUILD_DIR)
mkdir -p ./lib mkdir -p ./lib
runfile: runfile: packjar buildfile
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
runhotwords: runhotwords:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
@@ -85,8 +86,7 @@ buildlib: $(LIB_FILES:.java=.class)
%.class: %.java %.class: %.java
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
buildwebsocket: $(WEBSOCKET_FILES:.java=.class) buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
@@ -95,7 +95,7 @@ buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $<
packjar: packjar: buildlib
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
all: clean buildlib packjar buildfile buildmic downjar buildwebsocket all: clean buildlib packjar buildfile buildmic downjar buildwebsocket

2
sherpa-onnx/java-api/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.idea
java-api.iml

View File

@@ -5,25 +5,25 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class EndpointConfig { public class EndpointConfig {
private final EndpointRule rule1; private final EndpointRule rule1;
private final EndpointRule rule2; private final EndpointRule rule2;
private final EndpointRule rule3; private final EndpointRule rule3;
public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) { public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) {
this.rule1 = rule1; this.rule1 = rule1;
this.rule2 = rule2; this.rule2 = rule2;
this.rule3 = rule3; this.rule3 = rule3;
} }
public EndpointRule getRule1() { public EndpointRule getRule1() {
return rule1; return rule1;
} }
public EndpointRule getRule2() { public EndpointRule getRule2() {
return rule2; return rule2;
} }
public EndpointRule getRule3() { public EndpointRule getRule3() {
return rule3; return rule3;
} }
} }

View File

@@ -5,26 +5,26 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class EndpointRule { public class EndpointRule {
private final boolean mustContainNonSilence; private final boolean mustContainNonSilence;
private final float minTrailingSilence; private final float minTrailingSilence;
private final float minUtteranceLength; private final float minUtteranceLength;
public EndpointRule( public EndpointRule(
boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) { boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) {
this.mustContainNonSilence = mustContainNonSilence; this.mustContainNonSilence = mustContainNonSilence;
this.minTrailingSilence = minTrailingSilence; this.minTrailingSilence = minTrailingSilence;
this.minUtteranceLength = minUtteranceLength; this.minUtteranceLength = minUtteranceLength;
} }
public float getMinTrailingSilence() { public float getMinTrailingSilence() {
return minTrailingSilence; return minTrailingSilence;
} }
public float getMinUtteranceLength() { public float getMinUtteranceLength() {
return minUtteranceLength; return minUtteranceLength;
} }
public boolean getMustContainNonSilence() { public boolean getMustContainNonSilence() {
return mustContainNonSilence; return mustContainNonSilence;
} }
} }

View File

@@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class FeatureConfig { public class FeatureConfig {
private final int sampleRate; private final int sampleRate;
private final int featureDim; private final int featureDim;
public FeatureConfig(int sampleRate, int featureDim) { public FeatureConfig(int sampleRate, int featureDim) {
this.sampleRate = sampleRate; this.sampleRate = sampleRate;
this.featureDim = featureDim; this.featureDim = featureDim;
} }
public int getSampleRate() { public int getSampleRate() {
return sampleRate; return sampleRate;
} }
public int getFeatureDim() { public int getFeatureDim() {
return featureDim; return featureDim;
} }
} }

View File

@@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class OnlineLMConfig { public class OnlineLMConfig {
private final String model; private final String model;
private final float scale; private final float scale;
public OnlineLMConfig(String model, float scale) { public OnlineLMConfig(String model, float scale) {
this.model = model; this.model = model;
this.scale = scale; this.scale = scale;
} }
public String getModel() { public String getModel() {
return model; return model;
} }
public float getScale() { public float getScale() {
return scale; return scale;
} }
} }

View File

@@ -5,47 +5,51 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class OnlineModelConfig { public class OnlineModelConfig {
private final OnlineParaformerModelConfig paraformer; private final OnlineParaformerModelConfig paraformer;
private final OnlineTransducerModelConfig transducer; private final OnlineTransducerModelConfig transducer;
private final String tokens; private final OnlineZipformer2CtcModelConfig zipformer2Ctc;
private final int numThreads; private final String tokens;
private final boolean debug; private final int numThreads;
private final String provider = "cpu"; private final boolean debug;
private String modelType = ""; private final String provider = "cpu";
private String modelType = "";
public OnlineModelConfig( public OnlineModelConfig(
String tokens, String tokens,
int numThreads, int numThreads,
boolean debug, boolean debug,
String modelType, String modelType,
OnlineParaformerModelConfig paraformer, OnlineParaformerModelConfig paraformer,
OnlineTransducerModelConfig transducer) { OnlineTransducerModelConfig transducer,
OnlineZipformer2CtcModelConfig zipformer2Ctc
) {
this.tokens = tokens; this.tokens = tokens;
this.numThreads = numThreads; this.numThreads = numThreads;
this.debug = debug; this.debug = debug;
this.modelType = modelType; this.modelType = modelType;
this.paraformer = paraformer; this.paraformer = paraformer;
this.transducer = transducer; this.transducer = transducer;
} this.zipformer2Ctc = zipformer2Ctc;
}
public OnlineParaformerModelConfig getParaformer() { public OnlineParaformerModelConfig getParaformer() {
return paraformer; return paraformer;
} }
public OnlineTransducerModelConfig getTransducer() { public OnlineTransducerModelConfig getTransducer() {
return transducer; return transducer;
} }
public String getTokens() { public String getTokens() {
return tokens; return tokens;
} }
public int getNumThreads() { public int getNumThreads() {
return numThreads; return numThreads;
} }
public boolean getDebug() { public boolean getDebug() {
return debug; return debug;
} }
} }

View File

@@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class OnlineParaformerModelConfig { public class OnlineParaformerModelConfig {
private final String encoder; private final String encoder;
private final String decoder; private final String decoder;
public OnlineParaformerModelConfig(String encoder, String decoder) { public OnlineParaformerModelConfig(String encoder, String decoder) {
this.encoder = encoder; this.encoder = encoder;
this.decoder = decoder; this.decoder = decoder;
} }
public String getEncoder() { public String getEncoder() {
return encoder; return encoder;
} }
public String getDecoder() { public String getDecoder() {
return decoder; return decoder;
} }
} }

View File

@@ -32,336 +32,345 @@ usage example:
*/ */
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
import java.io.*; import java.io.BufferedInputStream;
import java.util.*; import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
public class OnlineRecognizer { public class OnlineRecognizer {
private long ptr = 0; // this is the asr engine ptrss private long ptr = 0; // this is the asr engine ptrss
private int sampleRate = 16000; private int sampleRate = 16000;
// load config file for OnlineRecognizer // load config file for OnlineRecognizer
public OnlineRecognizer(String modelCfgPath) { public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath); Map<String, String> proMap = this.readProperties(modelCfgPath);
try { try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate; this.sampleRate = sampleRate;
EndpointRule rule1 = EndpointRule rule1 =
new EndpointRule( new EndpointRule(
false, false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F); 0.0F);
EndpointRule rule2 = EndpointRule rule2 =
new EndpointRule( new EndpointRule(
true, true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F); 0.0F);
EndpointRule rule3 = EndpointRule rule3 =
new EndpointRule( new EndpointRule(
false, false,
0.0F, 0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig( new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg = OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig( new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(), proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim()); proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg = OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
new OnlineModelConfig( OnlineModelConfig modelCfg =
proMap.getOrDefault("tokens", "").trim(), new OnlineModelConfig(
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), proMap.getOrDefault("tokens", "").trim(),
false, Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
proMap.getOrDefault("model_type", "zipformer").trim(), false,
modelParaCfg, proMap.getOrDefault("model_type", "zipformer").trim(),
modelTranCfg); modelParaCfg,
FeatureConfig featConfig = modelTranCfg, zipformer2CtcConfig);
new FeatureConfig( FeatureConfig featConfig =
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); new FeatureConfig(
OnlineLMConfig onlineLmConfig = sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
new OnlineLMConfig( OnlineLMConfig onlineLmConfig =
proMap.getOrDefault("lm_model", "").trim(), new OnlineLMConfig(
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg = OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig( new OnlineRecognizerConfig(
featConfig, featConfig,
modelCfg, modelCfg,
endCfg, endCfg,
onlineLmConfig, onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(), proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg); this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
} catch (Exception e) { } catch (Exception e) {
System.err.println(e); System.err.println(e);
}
} }
}
// use for android asset_manager ANDROID_API__ >= 9 // use for android asset_manager ANDROID_API__ >= 9
public OnlineRecognizer(Object assetManager, String modelCfgPath) { public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath); Map<String, String> proMap = this.readProperties(modelCfgPath);
try { try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate; this.sampleRate = sampleRate;
EndpointRule rule1 = EndpointRule rule1 =
new EndpointRule( new EndpointRule(
false, false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F); 0.0F);
EndpointRule rule2 = EndpointRule rule2 =
new EndpointRule( new EndpointRule(
true, true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F); 0.0F);
EndpointRule rule3 = EndpointRule rule3 =
new EndpointRule( new EndpointRule(
false, false,
0.0F, 0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig( new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg = OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig( new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(), proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim()); proMap.getOrDefault("joiner", "").trim());
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg = OnlineModelConfig modelCfg =
new OnlineModelConfig( new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(), proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false, false,
proMap.getOrDefault("model_type", "zipformer").trim(), proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg, modelParaCfg,
modelTranCfg); modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig = FeatureConfig featConfig =
new FeatureConfig( new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig = OnlineLMConfig onlineLmConfig =
new OnlineLMConfig( new OnlineLMConfig(
proMap.getOrDefault("lm_model", "").trim(), proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg = OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig( new OnlineRecognizerConfig(
featConfig, featConfig,
modelCfg, modelCfg,
endCfg, endCfg,
onlineLmConfig, onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(), proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg); this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
} catch (Exception e) { } catch (Exception e) {
System.err.println(e); System.err.println(e);
}
} }
}
// set onlineRecognizer by parameter // set onlineRecognizer by parameter
public OnlineRecognizer( public OnlineRecognizer(
String tokens, String tokens,
String encoder, String encoder,
String decoder, String decoder,
String joiner, String joiner,
int numThreads, int numThreads,
int sampleRate, int sampleRate,
int featureDim, int featureDim,
boolean enableEndpointDetection, boolean enableEndpointDetection,
float rule1MinTrailingSilence, float rule1MinTrailingSilence,
float rule2MinTrailingSilence, float rule2MinTrailingSilence,
float rule3MinUtteranceLength, float rule3MinUtteranceLength,
String decodingMethod, String decodingMethod,
String lm_model, String lm_model,
float lm_scale, float lm_scale,
int maxActivePaths, int maxActivePaths,
String hotwordsFile, String hotwordsFile,
float hotwordsScore, float hotwordsScore,
String modelType) { String modelType) {
this.sampleRate = sampleRate; this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
OnlineTransducerModelConfig modelTranCfg = OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(encoder, decoder, joiner); new OnlineTransducerModelConfig(encoder, decoder, joiner);
OnlineModelConfig modelCfg = OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg); OnlineModelConfig modelCfg =
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig);
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineRecognizerConfig rcgCfg = OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
new OnlineRecognizerConfig( OnlineRecognizerConfig rcgCfg =
featConfig, new OnlineRecognizerConfig(
modelCfg, featConfig,
endCfg, modelCfg,
onlineLmConfig, endCfg,
enableEndpointDetection, onlineLmConfig,
decodingMethod, enableEndpointDetection,
maxActivePaths, decodingMethod,
hotwordsFile, maxActivePaths,
hotwordsScore); hotwordsFile,
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 hotwordsScore);
this.ptr = createOnlineRecognizer(new Object(), rcgCfg); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
} this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
private Map<String, String> readProperties(String modelCfgPath) {
// read and parse config file
Properties props = new Properties();
Map<String, String> proMap = new HashMap<>();
try {
File file = new File(modelCfgPath);
if (!file.exists()) {
System.out.println("model cfg file not exists!");
System.exit(0);
}
InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath));
props.load(in);
Enumeration en = props.propertyNames();
while (en.hasMoreElements()) {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
}
} catch (Exception e) {
e.printStackTrace();
} }
return proMap;
}
public void decodeStream(OnlineStream s) throws Exception { public static float[] readWavFile(String fileName) {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); // read data from the filename
long streamPtr = s.getPtr(); Object[] wavdata = readWave(fileName);
if (streamPtr == 0) throw new Exception("null exception for stream ptr"); Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
// when feeded samples to engine, call DecodeStream to let it process
decodeStream(this.ptr, streamPtr);
}
public void decodeStreams(OnlineStream[] ssOjb) throws Exception { float[] floatData = (float[]) data;
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
// decode for multiple streams return floatData;
long[] ss = new long[ssOjb.length];
for (int i = 0; i < ssOjb.length; i++) {
ss[i] = ssOjb[i].getPtr();
if (ss[i] == 0) throw new Exception("null exception for stream ptr");
} }
decodeStreams(this.ptr, ss);
}
public boolean isReady(OnlineStream s) throws Exception { // load the libsherpa-onnx-jni.so lib
// whether the engine is ready for decode public static void loadSoLib(String soPath) {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); // load libsherpa-onnx-jni.so lib from the path
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isReady(this.ptr, streamPtr);
}
public String getResult(OnlineStream s) throws Exception { System.out.println("so lib path=" + soPath + "\n");
// get text from the engine System.load(soPath.trim());
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); System.out.println("load so lib succeed\n");
long streamPtr = s.getPtr(); }
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return getResult(this.ptr, streamPtr);
}
public boolean isEndpoint(OnlineStream s) throws Exception { public static void setSoPath(String soPath) {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); OnlineRecognizer.loadSoLib(soPath);
long streamPtr = s.getPtr(); OnlineStream.loadSoLib(soPath);
if (streamPtr == 0) throw new Exception("null exception for stream ptr"); }
return isEndpoint(this.ptr, streamPtr);
}
public void reSet(OnlineStream s) throws Exception { private static native Object[] readWave(String fileName); // static
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr, streamPtr);
}
public OnlineStream createStream() throws Exception { private Map<String, String> readProperties(String modelCfgPath) {
// create one stream for data to feed in // read and parse config file
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); Properties props = new Properties();
long streamPtr = createStream(this.ptr); Map<String, String> proMap = new HashMap<>();
OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate); try {
return stream; File file = new File(modelCfgPath);
} if (!file.exists()) {
System.out.println("model cfg file not exists!");
System.exit(0);
}
InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath));
props.load(in);
Enumeration en = props.propertyNames();
while (en.hasMoreElements()) {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
}
public static float[] readWavFile(String fileName) { } catch (Exception e) {
// read data from the filename e.printStackTrace();
Object[] wavdata = readWave(fileName); }
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate return proMap;
}
float[] floatData = (float[]) data; public void decodeStream(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
// when feeded samples to engine, call DecodeStream to let it process
decodeStream(this.ptr, streamPtr);
}
return floatData; public void decodeStreams(OnlineStream[] ssOjb) throws Exception {
} if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
// decode for multiple streams
long[] ss = new long[ssOjb.length];
for (int i = 0; i < ssOjb.length; i++) {
ss[i] = ssOjb[i].getPtr();
if (ss[i] == 0) throw new Exception("null exception for stream ptr");
}
decodeStreams(this.ptr, ss);
}
// load the libsherpa-onnx-jni.so lib public boolean isReady(OnlineStream s) throws Exception {
public static void loadSoLib(String soPath) { // whether the engine is ready for decode
// load libsherpa-onnx-jni.so lib from the path if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isReady(this.ptr, streamPtr);
}
System.out.println("so lib path=" + soPath + "\n"); public String getResult(OnlineStream s) throws Exception {
System.load(soPath.trim()); // get text from the engine
System.out.println("load so lib succeed\n"); if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
} long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return getResult(this.ptr, streamPtr);
}
public static void setSoPath(String soPath) { public boolean isEndpoint(OnlineStream s) throws Exception {
OnlineRecognizer.loadSoLib(soPath); if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
OnlineStream.loadSoLib(soPath); long streamPtr = s.getPtr();
} if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isEndpoint(this.ptr, streamPtr);
}
protected void finalize() throws Throwable { public void reSet(OnlineStream s) throws Exception {
release(); if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
} long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr, streamPtr);
}
// recognizer release, you'd better call it manually if not use anymore public OnlineStream createStream() throws Exception {
public void release() { // create one stream for data to feed in
if (this.ptr == 0) return; if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
deleteOnlineRecognizer(this.ptr); long streamPtr = createStream(this.ptr);
this.ptr = 0; OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate);
} return stream;
}
// stream release, you'd better call it manually if not use anymore protected void finalize() throws Throwable {
public void releaseStream(OnlineStream s) { release();
s.release(); }
}
// JNI interface libsherpa-onnx-jni.so // recognizer release, you'd better call it manually if not use anymore
public void release() {
if (this.ptr == 0) return;
deleteOnlineRecognizer(this.ptr);
this.ptr = 0;
}
private static native Object[] readWave(String fileName); // static // JNI interface libsherpa-onnx-jni.so
private native String getResult(long ptr, long streamPtr); // stream release, you'd better call it manually if not use anymore
public void releaseStream(OnlineStream s) {
s.release();
}
private native void decodeStream(long ptr, long streamPtr); private native String getResult(long ptr, long streamPtr);
private native void decodeStreams(long ptr, long[] ssPtr); private native void decodeStream(long ptr, long streamPtr);
private native boolean isReady(long ptr, long streamPtr); private native void decodeStreams(long ptr, long[] ssPtr);
// first parameter keep for android asset_manager ANDROID_API__ >= 9 private native boolean isReady(long ptr, long streamPtr);
private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
private native long createStream(long ptr); // first parameter keep for android asset_manager ANDROID_API__ >= 9
private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
private native void deleteOnlineRecognizer(long ptr); private native long createStream(long ptr);
private native boolean isEndpoint(long ptr, long streamPtr); private native void deleteOnlineRecognizer(long ptr);
private native void reSet(long ptr, long streamPtr); private native boolean isEndpoint(long ptr, long streamPtr);
private native void reSet(long ptr, long streamPtr);
} }

View File

@@ -5,62 +5,62 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class OnlineRecognizerConfig { public class OnlineRecognizerConfig {
private final FeatureConfig featConfig; private final FeatureConfig featConfig;
private final OnlineModelConfig modelConfig; private final OnlineModelConfig modelConfig;
private final EndpointConfig endpointConfig; private final EndpointConfig endpointConfig;
private final OnlineLMConfig lmConfig; private final OnlineLMConfig lmConfig;
private final boolean enableEndpoint; private final boolean enableEndpoint;
private final String decodingMethod; private final String decodingMethod;
private final int maxActivePaths; private final int maxActivePaths;
private final String hotwordsFile; private final String hotwordsFile;
private final float hotwordsScore; private final float hotwordsScore;
public OnlineRecognizerConfig( public OnlineRecognizerConfig(
FeatureConfig featConfig, FeatureConfig featConfig,
OnlineModelConfig modelConfig, OnlineModelConfig modelConfig,
EndpointConfig endpointConfig, EndpointConfig endpointConfig,
OnlineLMConfig lmConfig, OnlineLMConfig lmConfig,
boolean enableEndpoint, boolean enableEndpoint,
String decodingMethod, String decodingMethod,
int maxActivePaths, int maxActivePaths,
String hotwordsFile, String hotwordsFile,
float hotwordsScore) { float hotwordsScore) {
this.featConfig = featConfig; this.featConfig = featConfig;
this.modelConfig = modelConfig; this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig; this.endpointConfig = endpointConfig;
this.lmConfig = lmConfig; this.lmConfig = lmConfig;
this.enableEndpoint = enableEndpoint; this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod; this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths; this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile; this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore; this.hotwordsScore = hotwordsScore;
} }
public OnlineLMConfig getLmConfig() { public OnlineLMConfig getLmConfig() {
return lmConfig; return lmConfig;
} }
public FeatureConfig getFeatConfig() { public FeatureConfig getFeatConfig() {
return featConfig; return featConfig;
} }
public OnlineModelConfig getModelConfig() { public OnlineModelConfig getModelConfig() {
return modelConfig; return modelConfig;
} }
public EndpointConfig getEndpointConfig() { public EndpointConfig getEndpointConfig() {
return endpointConfig; return endpointConfig;
} }
public boolean isEnableEndpoint() { public boolean isEnableEndpoint() {
return enableEndpoint; return enableEndpoint;
} }
public String getDecodingMethod() { public String getDecodingMethod() {
return decodingMethod; return decodingMethod;
} }
public int getMaxActivePaths() { public int getMaxActivePaths() {
return maxActivePaths; return maxActivePaths;
} }
} }

View File

@@ -4,83 +4,81 @@
// Stream is used for feeding data to the asr engine // Stream is used for feeding data to the asr engine
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
import java.io.*;
import java.util.*;
public class OnlineStream { public class OnlineStream {
private long ptr = 0; // this is the stream ptr private long ptr = 0; // this is the stream ptr
private int sampleRate = 16000; private int sampleRate = 16000;
// assign ptr to this stream in construction
public OnlineStream(long ptr, int sampleRate) {
this.ptr = ptr;
this.sampleRate = sampleRate;
}
public long getPtr() { // assign ptr to this stream in construction
return ptr; public OnlineStream(long ptr, int sampleRate) {
} this.ptr = ptr;
this.sampleRate = sampleRate;
}
public void acceptWaveform(float[] samples) throws Exception { public static void loadSoLib(String soPath) {
if (this.ptr == 0) throw new Exception("null exception for stream ptr"); // load .so lib from the path
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
}
// feed wave data to asr engine public long getPtr() {
acceptWaveform(this.ptr, this.sampleRate, samples); return ptr;
} }
public void inputFinished() { public void acceptWaveform(float[] samples) throws Exception {
// add some tail padding if (this.ptr == 0) throw new Exception("null exception for stream ptr");
int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
float tailPaddings[] = new float[padLen]; // default value is 0
acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
// tell the engine all data are feeded // feed wave data to asr engine
inputFinished(this.ptr); acceptWaveform(this.ptr, this.sampleRate, samples);
} }
public static void loadSoLib(String soPath) { public void inputFinished() {
// load .so lib from the path // add some tail padding
System.load(soPath.trim()); // ("sherpa-onnx-jni-java"); int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
} float[] tailPaddings = new float[padLen]; // default value is 0
acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
public void release() { // tell the engine all data are feeded
// stream object must be release after used inputFinished(this.ptr);
if (this.ptr == 0) return; }
deleteStream(this.ptr);
this.ptr = 0;
}
protected void finalize() throws Throwable { public void release() {
release(); // stream object must be release after used
} if (this.ptr == 0) return;
deleteStream(this.ptr);
this.ptr = 0;
}
public boolean isLastFrame() throws Exception { protected void finalize() throws Throwable {
if (this.ptr == 0) throw new Exception("null exception for stream ptr"); release();
return isLastFrame(this.ptr); }
}
public void reSet() throws Exception { public boolean isLastFrame() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr"); if (this.ptr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr); return isLastFrame(this.ptr);
} }
public int featureDim() throws Exception { public void reSet() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr"); if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return featureDim(this.ptr); reSet(this.ptr);
} }
// JNI interface libsherpa-onnx-jni.so public int featureDim() throws Exception {
private native void acceptWaveform(long ptr, int sampleRate, float[] samples); if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return featureDim(this.ptr);
}
private native void inputFinished(long ptr); // JNI interface libsherpa-onnx-jni.so
private native void acceptWaveform(long ptr, int sampleRate, float[] samples);
private native void deleteStream(long ptr); private native void inputFinished(long ptr);
private native int numFramesReady(long ptr); private native void deleteStream(long ptr);
private native boolean isLastFrame(long ptr); private native int numFramesReady(long ptr);
private native void reSet(long ptr); private native boolean isLastFrame(long ptr);
private native int featureDim(long ptr); private native void reSet(long ptr);
private native int featureDim(long ptr);
} }

View File

@@ -5,25 +5,25 @@
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
public class OnlineTransducerModelConfig { public class OnlineTransducerModelConfig {
private final String encoder; private final String encoder;
private final String decoder; private final String decoder;
private final String joiner; private final String joiner;
public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) { public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
this.encoder = encoder; this.encoder = encoder;
this.decoder = decoder; this.decoder = decoder;
this.joiner = joiner; this.joiner = joiner;
} }
public String getEncoder() { public String getEncoder() {
return encoder; return encoder;
} }
public String getDecoder() { public String getDecoder() {
return decoder; return decoder;
} }
public String getJoiner() { public String getJoiner() {
return joiner; return joiner;
} }
} }

View File

@@ -0,0 +1,14 @@
package com.k2fsa.sherpa.onnx;
public class OnlineZipformer2CtcModelConfig {
private final String model;
public OnlineZipformer2CtcModelConfig(String model) {
this.model = model;
}
public String getModel() {
return model;
}
}

View File

@@ -1522,8 +1522,8 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(
SHERPA_ONNX_EXTERN_C SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/, JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate,
jlong ptr, jboolean recreate, jstring keywords) { jstring keywords) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
model->Reset(recreate, p_keywords); model->Reset(recreate, p_keywords);