Java api update for adding modelType in config class (#228)
This commit is contained in:
@@ -39,6 +39,7 @@ public class OnlineRecognizer {
|
||||
private long ptr = 0; // this is the asr engine ptrss
|
||||
|
||||
private int sampleRate = 16000;
|
||||
|
||||
// load config file for OnlineRecognizer
|
||||
public OnlineRecognizer(String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
@@ -62,17 +63,20 @@ public class OnlineRecognizer {
|
||||
proMap.get("joiner").trim(),
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
false);
|
||||
false,
|
||||
proMap.get("model_type").trim());
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
@@ -107,18 +111,21 @@ public class OnlineRecognizer {
|
||||
proMap.get("joiner").trim(),
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
false);
|
||||
false,
|
||||
proMap.get("model_type").trim());
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
|
||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
@@ -144,21 +151,29 @@ public class OnlineRecognizer {
|
||||
float rule2MinTrailingSilence,
|
||||
float rule3MinUtteranceLength,
|
||||
String decodingMethod,
|
||||
String lm_model,
|
||||
float lm_scale,
|
||||
int maxActivePaths) {
|
||||
String lm_model,
|
||||
float lm_scale,
|
||||
int maxActivePaths,
|
||||
String modelType) {
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
||||
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
OnlineTransducerModelConfig modelCfg =
|
||||
new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
|
||||
new OnlineTransducerModelConfig(
|
||||
encoder, decoder, joiner, tokens, numThreads, false, modelType);
|
||||
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths);
|
||||
featConfig,
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
enableEndpointDetection,
|
||||
decodingMethod,
|
||||
maxActivePaths);
|
||||
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
}
|
||||
@@ -284,9 +299,10 @@ public class OnlineRecognizer {
|
||||
public void releaseStream(OnlineStream s) {
|
||||
s.release();
|
||||
}
|
||||
|
||||
// JNI interface libsherpa-onnx-jni.so
|
||||
|
||||
private static native Object[] readWave(String fileName); // static
|
||||
private static native Object[] readWave(String fileName); // static
|
||||
|
||||
private native String getResult(long ptr, long streamPtr);
|
||||
|
||||
|
||||
@@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig {
|
||||
private final String tokens;
|
||||
private final int numThreads;
|
||||
private final boolean debug;
|
||||
private final String provider = "cpu";
|
||||
private String modelType = "";
|
||||
|
||||
public OnlineTransducerModelConfig(
|
||||
String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) {
|
||||
String encoder,
|
||||
String decoder,
|
||||
String joiner,
|
||||
String tokens,
|
||||
int numThreads,
|
||||
boolean debug,
|
||||
String modelType) {
|
||||
this.encoder = encoder;
|
||||
this.decoder = decoder;
|
||||
this.joiner = joiner;
|
||||
this.tokens = tokens;
|
||||
this.numThreads = numThreads;
|
||||
this.debug = debug;
|
||||
this.modelType = modelType;
|
||||
}
|
||||
|
||||
public String getEncoder() {
|
||||
|
||||
Reference in New Issue
Block a user