Java api update for adding modelType in config class (#228)
This commit is contained in:
@@ -4,16 +4,17 @@ feature_dim=80
|
|||||||
rule1_min_trailing_silence=2.4
|
rule1_min_trailing_silence=2.4
|
||||||
rule2_min_trailing_silence=1.2
|
rule2_min_trailing_silence=1.2
|
||||||
rule3_min_utterance_length=20
|
rule3_min_utterance_length=20
|
||||||
encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
|
encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
|
||||||
decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
|
decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
|
||||||
joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
|
joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
|
||||||
tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
||||||
num_threads=4
|
num_threads=4
|
||||||
enable_endpoint_detection=true
|
enable_endpoint_detection=true
|
||||||
decoding_method=modified_beam_search
|
decoding_method=modified_beam_search
|
||||||
max_active_paths=4
|
max_active_paths=4
|
||||||
lm_model=
|
lm_model=
|
||||||
lm_scale=0.5
|
lm_scale=0.5
|
||||||
|
model_type=zipformer
|
||||||
|
|
||||||
#websocket server config
|
#websocket server config
|
||||||
port=8890
|
port=8890
|
||||||
|
|||||||
@@ -49,8 +49,9 @@ public class DecodeFile {
|
|||||||
float rule3MinUtteranceLength = 20F;
|
float rule3MinUtteranceLength = 20F;
|
||||||
String decodingMethod = "greedy_search";
|
String decodingMethod = "greedy_search";
|
||||||
int maxActivePaths = 4;
|
int maxActivePaths = 4;
|
||||||
String lm_model="";
|
String lm_model = "";
|
||||||
float lm_scale=0.5F;
|
float lm_scale = 0.5F;
|
||||||
|
String modelType = "zipformer";
|
||||||
rcgOjb =
|
rcgOjb =
|
||||||
new OnlineRecognizer(
|
new OnlineRecognizer(
|
||||||
tokens,
|
tokens,
|
||||||
@@ -65,9 +66,10 @@ public class DecodeFile {
|
|||||||
rule2MinTrailingSilence,
|
rule2MinTrailingSilence,
|
||||||
rule3MinUtteranceLength,
|
rule3MinUtteranceLength,
|
||||||
decodingMethod,
|
decodingMethod,
|
||||||
lm_model,
|
lm_model,
|
||||||
lm_scale,
|
lm_scale,
|
||||||
maxActivePaths);
|
maxActivePaths,
|
||||||
|
modelType);
|
||||||
streamObj = rcgOjb.createStream();
|
streamObj = rcgOjb.createStream();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
System.err.println(e);
|
System.err.println(e);
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ public class OnlineRecognizer {
|
|||||||
private long ptr = 0; // this is the ASR engine pointer
|
private long ptr = 0; // this is the ASR engine pointer
|
||||||
|
|
||||||
private int sampleRate = 16000;
|
private int sampleRate = 16000;
|
||||||
|
|
||||||
// load config file for OnlineRecognizer
|
// load config file for OnlineRecognizer
|
||||||
public OnlineRecognizer(String modelCfgPath) {
|
public OnlineRecognizer(String modelCfgPath) {
|
||||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||||
@@ -62,17 +63,20 @@ public class OnlineRecognizer {
|
|||||||
proMap.get("joiner").trim(),
|
proMap.get("joiner").trim(),
|
||||||
proMap.get("tokens").trim(),
|
proMap.get("tokens").trim(),
|
||||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||||
false);
|
false,
|
||||||
|
proMap.get("model_type").trim());
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
OnlineLMConfig onlineLmConfig =
|
||||||
|
new OnlineLMConfig(
|
||||||
OnlineRecognizerConfig rcgCfg =
|
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||||
|
|
||||||
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig,
|
featConfig,
|
||||||
modelCfg,
|
modelCfg,
|
||||||
endCfg,
|
endCfg,
|
||||||
onlineLmConfig,
|
onlineLmConfig,
|
||||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||||
proMap.get("decoding_method").trim(),
|
proMap.get("decoding_method").trim(),
|
||||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||||
@@ -107,18 +111,21 @@ public class OnlineRecognizer {
|
|||||||
proMap.get("joiner").trim(),
|
proMap.get("joiner").trim(),
|
||||||
proMap.get("tokens").trim(),
|
proMap.get("tokens").trim(),
|
||||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||||
false);
|
false,
|
||||||
|
proMap.get("model_type").trim());
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
|
|
||||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
OnlineLMConfig onlineLmConfig =
|
||||||
|
new OnlineLMConfig(
|
||||||
OnlineRecognizerConfig rcgCfg =
|
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||||
|
|
||||||
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig,
|
featConfig,
|
||||||
modelCfg,
|
modelCfg,
|
||||||
endCfg,
|
endCfg,
|
||||||
onlineLmConfig,
|
onlineLmConfig,
|
||||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||||
proMap.get("decoding_method").trim(),
|
proMap.get("decoding_method").trim(),
|
||||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||||
@@ -144,21 +151,29 @@ public class OnlineRecognizer {
|
|||||||
float rule2MinTrailingSilence,
|
float rule2MinTrailingSilence,
|
||||||
float rule3MinUtteranceLength,
|
float rule3MinUtteranceLength,
|
||||||
String decodingMethod,
|
String decodingMethod,
|
||||||
String lm_model,
|
String lm_model,
|
||||||
float lm_scale,
|
float lm_scale,
|
||||||
int maxActivePaths) {
|
int maxActivePaths,
|
||||||
|
String modelType) {
|
||||||
this.sampleRate = sampleRate;
|
this.sampleRate = sampleRate;
|
||||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||||
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
||||||
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
||||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||||
OnlineTransducerModelConfig modelCfg =
|
OnlineTransducerModelConfig modelCfg =
|
||||||
new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
|
new OnlineTransducerModelConfig(
|
||||||
|
encoder, decoder, joiner, tokens, numThreads, false, modelType);
|
||||||
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
||||||
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);
|
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
|
||||||
OnlineRecognizerConfig rcgCfg =
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths);
|
featConfig,
|
||||||
|
modelCfg,
|
||||||
|
endCfg,
|
||||||
|
onlineLmConfig,
|
||||||
|
enableEndpointDetection,
|
||||||
|
decodingMethod,
|
||||||
|
maxActivePaths);
|
||||||
// create a new Recognizer; the first parameter is kept for the Android asset_manager (__ANDROID_API__ >= 9)
|
// create a new Recognizer; the first parameter is kept for the Android asset_manager (__ANDROID_API__ >= 9)
|
||||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||||
}
|
}
|
||||||
@@ -284,9 +299,10 @@ public class OnlineRecognizer {
|
|||||||
public void releaseStream(OnlineStream s) {
|
public void releaseStream(OnlineStream s) {
|
||||||
s.release();
|
s.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
// JNI interface libsherpa-onnx-jni.so
|
// JNI interface libsherpa-onnx-jni.so
|
||||||
|
|
||||||
private static native Object[] readWave(String fileName); // static
|
private static native Object[] readWave(String fileName); // static
|
||||||
|
|
||||||
private native String getResult(long ptr, long streamPtr);
|
private native String getResult(long ptr, long streamPtr);
|
||||||
|
|
||||||
|
|||||||
@@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig {
|
|||||||
private final String tokens;
|
private final String tokens;
|
||||||
private final int numThreads;
|
private final int numThreads;
|
||||||
private final boolean debug;
|
private final boolean debug;
|
||||||
|
private final String provider = "cpu";
|
||||||
|
private String modelType = "";
|
||||||
|
|
||||||
public OnlineTransducerModelConfig(
|
public OnlineTransducerModelConfig(
|
||||||
String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) {
|
String encoder,
|
||||||
|
String decoder,
|
||||||
|
String joiner,
|
||||||
|
String tokens,
|
||||||
|
int numThreads,
|
||||||
|
boolean debug,
|
||||||
|
String modelType) {
|
||||||
this.encoder = encoder;
|
this.encoder = encoder;
|
||||||
this.decoder = decoder;
|
this.decoder = decoder;
|
||||||
this.joiner = joiner;
|
this.joiner = joiner;
|
||||||
this.tokens = tokens;
|
this.tokens = tokens;
|
||||||
this.numThreads = numThreads;
|
this.numThreads = numThreads;
|
||||||
this.debug = debug;
|
this.debug = debug;
|
||||||
|
this.modelType = modelType;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getEncoder() {
|
public String getEncoder() {
|
||||||
|
|||||||
Reference in New Issue
Block a user