diff --git a/java-api-examples/modelconfig.cfg b/java-api-examples/modelconfig.cfg index 2e280778..becc0a03 100755 --- a/java-api-examples/modelconfig.cfg +++ b/java-api-examples/modelconfig.cfg @@ -4,16 +4,17 @@ feature_dim=80 rule1_min_trailing_silence=2.4 rule2_min_trailing_silence=1.2 rule3_min_utterance_length=20 -encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx -decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx -joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx -tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt +encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx +decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx +joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx +tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt num_threads=4 enable_endpoint_detection=true decoding_method=modified_beam_search max_active_paths=4 lm_model= lm_scale=0.5 +model_type=zipformer #websocket server config port=8890 diff --git a/java-api-examples/src/DecodeFile.java b/java-api-examples/src/DecodeFile.java index afbe3365..f3865a61 100644 --- a/java-api-examples/src/DecodeFile.java +++ b/java-api-examples/src/DecodeFile.java @@ -49,8 +49,9 @@ public class DecodeFile { float rule3MinUtteranceLength = 20F; String decodingMethod = "greedy_search"; int maxActivePaths = 4; - String lm_model=""; - float lm_scale=0.5F; + String lm_model = ""; + float lm_scale = 0.5F; + String modelType = "zipformer"; rcgOjb = new OnlineRecognizer( tokens, @@ -65,9 +66,10 @@ public class DecodeFile { rule2MinTrailingSilence, rule3MinUtteranceLength, decodingMethod, - lm_model, - lm_scale, - maxActivePaths); + lm_model, + lm_scale, + maxActivePaths, + modelType); streamObj = rcgOjb.createStream(); } catch (Exception e) { System.err.println(e); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 5658125d..2116320f 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -39,6 +39,7 @@ public class OnlineRecognizer { private long ptr = 0; // this is the asr engine ptrss private int sampleRate = 16000; + // load config file for OnlineRecognizer public OnlineRecognizer(String modelCfgPath) { Map proMap = this.readProperties(modelCfgPath); @@ -62,17 +63,20 @@ public class OnlineRecognizer { proMap.get("joiner").trim(), proMap.get("tokens").trim(), Integer.parseInt(proMap.get("num_threads").trim()), - false); + false, + proMap.get("model_type").trim()); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); - - OnlineRecognizerConfig rcgCfg = + OnlineLMConfig onlineLmConfig = + new OnlineLMConfig( + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); + + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( featConfig, modelCfg, endCfg, - onlineLmConfig, + onlineLmConfig, Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), proMap.get("decoding_method").trim(), Integer.parseInt(proMap.get("max_active_paths").trim())); @@ -107,18 +111,21 @@ public class OnlineRecognizer { proMap.get("joiner").trim(), proMap.get("tokens").trim(), Integer.parseInt(proMap.get("num_threads").trim()), - false); + false, + proMap.get("model_type").trim()); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); - - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); - - OnlineRecognizerConfig rcgCfg = + + OnlineLMConfig onlineLmConfig = + new OnlineLMConfig( + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); + + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( featConfig, modelCfg, endCfg, - onlineLmConfig, + onlineLmConfig, Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), proMap.get("decoding_method").trim(), Integer.parseInt(proMap.get("max_active_paths").trim())); @@ -144,21 +151,29 @@ public class OnlineRecognizer { float rule2MinTrailingSilence, float rule3MinUtteranceLength, String decodingMethod, - String lm_model, - float lm_scale, - int maxActivePaths) { + String lm_model, + float lm_scale, + int maxActivePaths, + String modelType) { this.sampleRate = sampleRate; EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); OnlineTransducerModelConfig modelCfg = - new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); + new OnlineTransducerModelConfig( + encoder, decoder, joiner, tokens, numThreads, false, modelType); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale); - OnlineRecognizerConfig rcgCfg = + OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( - featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); + featConfig, + modelCfg, + endCfg, + onlineLmConfig, + enableEndpointDetection, + decodingMethod, + maxActivePaths); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); } @@ -284,9 +299,10 @@ public class OnlineRecognizer { public void releaseStream(OnlineStream s) { s.release(); } + // JNI interface libsherpa-onnx-jni.so - private static native Object[] readWave(String fileName); // static + private static native Object[] readWave(String fileName); // static private native String getResult(long ptr, long streamPtr); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java index 1e45e371..7697bc51 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig { private final String tokens; private final int numThreads; private final boolean debug; + private final String provider = "cpu"; + private String modelType = ""; public OnlineTransducerModelConfig( - String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) { + String encoder, + String decoder, + String joiner, + String tokens, + int numThreads, + boolean debug, + String modelType) { this.encoder = encoder; this.decoder = decoder; this.joiner = joiner; this.tokens = tokens; this.numThreads = numThreads; this.debug = debug; + this.modelType = modelType; } public String getEncoder() {