From fb04366179269413adc7361f72c0d52353699e62 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 26 Feb 2024 13:49:37 +0800 Subject: [PATCH] Fix #608 (#610) Fix java tests. --- java-api-examples/.gitignore | 2 + java-api-examples/Makefile | 18 +- sherpa-onnx/java-api/.gitignore | 2 + .../com/k2fsa/sherpa/onnx/EndpointConfig.java | 34 +- .../com/k2fsa/sherpa/onnx/EndpointRule.java | 36 +- .../com/k2fsa/sherpa/onnx/FeatureConfig.java | 24 +- .../com/k2fsa/sherpa/onnx/OnlineLMConfig.java | 24 +- .../k2fsa/sherpa/onnx/OnlineModelConfig.java | 76 +-- .../onnx/OnlineParaformerModelConfig.java | 24 +- .../k2fsa/sherpa/onnx/OnlineRecognizer.java | 585 +++++++++--------- .../sherpa/onnx/OnlineRecognizerConfig.java | 100 +-- .../com/k2fsa/sherpa/onnx/OnlineStream.java | 116 ++-- .../onnx/OnlineTransducerModelConfig.java | 34 +- .../onnx/OnlineZipformer2CtcModelConfig.java | 14 + sherpa-onnx/jni/jni.cc | 4 +- 15 files changed, 561 insertions(+), 532 deletions(-) create mode 100644 java-api-examples/.gitignore create mode 100644 sherpa-onnx/java-api/.gitignore create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java diff --git a/java-api-examples/.gitignore b/java-api-examples/.gitignore new file mode 100644 index 00000000..0c17ace0 --- /dev/null +++ b/java-api-examples/.gitignore @@ -0,0 +1,2 @@ +lib +hs_err* diff --git a/java-api-examples/Makefile b/java-api-examples/Makefile index 4643ca74..619d6b43 100755 --- a/java-api-examples/Makefile +++ b/java-api-examples/Makefile @@ -9,10 +9,11 @@ LIB_FILES = \ $(LIB_SRC_DIR)/OnlineLMConfig.java \ $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ $(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \ + $(LIB_SRC_DIR)/OnlineZipformer2CtcModelConfig.java \ $(LIB_SRC_DIR)/OnlineModelConfig.java \ $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ $(LIB_SRC_DIR)/OnlineStream.java \ - $(LIB_SRC_DIR)/OnlineRecognizer.java \ + $(LIB_SRC_DIR)/OnlineRecognizer.java WEBSOCKET_DIR:= ./src/websocketsrv WEBSOCKET_FILES = \ @@ -42,10 +43,10 @@ vpath %.java src buildfile: - $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE) + $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE) buildmic: - $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic) + $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic) rebuild: clean all @@ -63,8 +64,8 @@ clean: mkdir -p $(BUILD_DIR) mkdir -p ./lib -runfile: - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav +runfile: packjar buildfile + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav runhotwords: java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav @@ -85,8 +86,7 @@ buildlib: $(LIB_FILES:.java=.class) %.class: %.java - - $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< + $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< buildwebsocket: $(WEBSOCKET_FILES:.java=.class) @@ -95,7 +95,7 @@ buildwebsocket: $(WEBSOCKET_FILES:.java=.class) $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< -packjar: - jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . +packjar: buildlib + jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . all: clean buildlib packjar buildfile buildmic downjar buildwebsocket diff --git a/sherpa-onnx/java-api/.gitignore b/sherpa-onnx/java-api/.gitignore new file mode 100644 index 00000000..4934677e --- /dev/null +++ b/sherpa-onnx/java-api/.gitignore @@ -0,0 +1,2 @@ +.idea +java-api.iml diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java index 5f4b6d16..41c1c919 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java @@ -5,25 +5,25 @@ package com.k2fsa.sherpa.onnx; public class EndpointConfig { - private final EndpointRule rule1; - private final EndpointRule rule2; - private final EndpointRule rule3; + private final EndpointRule rule1; + private final EndpointRule rule2; + private final EndpointRule rule3; - public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) { - this.rule1 = rule1; - this.rule2 = rule2; - this.rule3 = rule3; - } + public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) { + this.rule1 = rule1; + this.rule2 = rule2; + this.rule3 = rule3; + } - public EndpointRule getRule1() { - return rule1; - } + public EndpointRule getRule1() { + return rule1; + } - public EndpointRule getRule2() { - return rule2; - } + public EndpointRule getRule2() { + return rule2; + } - public EndpointRule getRule3() { - return rule3; - } + public EndpointRule getRule3() { + return rule3; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java index 5a1714f6..7abcc7c5 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java @@ -5,26 +5,26 @@ package com.k2fsa.sherpa.onnx; public class EndpointRule { - private final boolean mustContainNonSilence; - private final float minTrailingSilence; - private final float minUtteranceLength; + private final boolean mustContainNonSilence; + private final float minTrailingSilence; + private final float minUtteranceLength; - public EndpointRule( - boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) { - this.mustContainNonSilence = mustContainNonSilence; - this.minTrailingSilence = minTrailingSilence; - this.minUtteranceLength = minUtteranceLength; - } + public EndpointRule( + boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) { + this.mustContainNonSilence = mustContainNonSilence; + this.minTrailingSilence = minTrailingSilence; + this.minUtteranceLength = minUtteranceLength; + } - public float getMinTrailingSilence() { - return minTrailingSilence; - } + public float getMinTrailingSilence() { + return minTrailingSilence; + } - public float getMinUtteranceLength() { - return minUtteranceLength; - } + public float getMinUtteranceLength() { + return minUtteranceLength; + } - public boolean getMustContainNonSilence() { - return mustContainNonSilence; - } + public boolean getMustContainNonSilence() { + return mustContainNonSilence; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java index 069b7897..381c28ac 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java @@ -5,19 +5,19 @@ package com.k2fsa.sherpa.onnx; public class FeatureConfig { - private final int sampleRate; - private final int featureDim; + private final int sampleRate; + private final int featureDim; - public FeatureConfig(int sampleRate, int featureDim) { - this.sampleRate = sampleRate; - this.featureDim = featureDim; - } + public FeatureConfig(int sampleRate, int featureDim) { + this.sampleRate = sampleRate; + this.featureDim = featureDim; + } - public int getSampleRate() { - return sampleRate; - } + public int getSampleRate() { + return sampleRate; + } - public int getFeatureDim() { - return featureDim; - } + public int getFeatureDim() { + return featureDim; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java index 7474a299..e94ca965 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java @@ -5,19 +5,19 @@ package com.k2fsa.sherpa.onnx; public class OnlineLMConfig { - private final String model; - private final float scale; + private final String model; + private final float scale; - public OnlineLMConfig(String model, float scale) { - this.model = model; - this.scale = scale; - } + public OnlineLMConfig(String model, float scale) { + this.model = model; + this.scale = scale; + } - public String getModel() { - return model; - } + public String getModel() { + return model; + } - public float getScale() { - return scale; - } + public float getScale() { + return scale; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java index 42e0a99e..eddf7361 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java @@ -5,47 +5,51 @@ package com.k2fsa.sherpa.onnx; public class OnlineModelConfig { - private final OnlineParaformerModelConfig paraformer; - private final OnlineTransducerModelConfig transducer; - private final String tokens; - private final int numThreads; - private final boolean debug; - private final String provider = "cpu"; - private String modelType = ""; + private final OnlineParaformerModelConfig paraformer; + private final OnlineTransducerModelConfig transducer; + private final OnlineZipformer2CtcModelConfig zipformer2Ctc; + private final String tokens; + private final int numThreads; + private final boolean debug; + private final String provider = "cpu"; + private String modelType = ""; - public OnlineModelConfig( - String tokens, - int numThreads, - boolean debug, - String modelType, - OnlineParaformerModelConfig paraformer, - OnlineTransducerModelConfig transducer) { + public OnlineModelConfig( + String tokens, + int numThreads, + boolean debug, + String modelType, + OnlineParaformerModelConfig paraformer, + OnlineTransducerModelConfig transducer, + OnlineZipformer2CtcModelConfig zipformer2Ctc + ) { - this.tokens = tokens; - this.numThreads = numThreads; - this.debug = debug; - this.modelType = modelType; - this.paraformer = paraformer; - this.transducer = transducer; - } + this.tokens = tokens; + this.numThreads = numThreads; + this.debug = debug; + this.modelType = modelType; + this.paraformer = paraformer; + this.transducer = transducer; + this.zipformer2Ctc = zipformer2Ctc; + } - public OnlineParaformerModelConfig getParaformer() { - return paraformer; - } + public OnlineParaformerModelConfig getParaformer() { + return paraformer; + } - public OnlineTransducerModelConfig getTransducer() { - return transducer; - } + public OnlineTransducerModelConfig getTransducer() { + return transducer; + } - public String getTokens() { - return tokens; - } + public String getTokens() { + return tokens; + } - public int getNumThreads() { - return numThreads; - } + public int getNumThreads() { + return numThreads; + } - public boolean getDebug() { - return debug; - } + public boolean getDebug() { + return debug; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java index c7643f6e..2f7017a0 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java @@ -5,19 +5,19 @@ package com.k2fsa.sherpa.onnx; public class OnlineParaformerModelConfig { - private final String encoder; - private final String decoder; + private final String encoder; + private final String decoder; - public OnlineParaformerModelConfig(String encoder, String decoder) { - this.encoder = encoder; - this.decoder = decoder; - } + public OnlineParaformerModelConfig(String encoder, String decoder) { + this.encoder = encoder; + this.decoder = decoder; + } - public String getEncoder() { - return encoder; - } + public String getEncoder() { + return encoder; + } - public String getDecoder() { - return decoder; - } + public String getDecoder() { + return decoder; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index d064c75d..15f07b07 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -32,336 +32,345 @@ usage example: */ package com.k2fsa.sherpa.onnx; -import java.io.*; -import java.util.*; +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; public class OnlineRecognizer { - private long ptr = 0; // this is the asr engine ptrss + private long ptr = 0; // this is the asr engine ptrss - private int sampleRate = 16000; + private int sampleRate = 16000; - // load config file for OnlineRecognizer - public OnlineRecognizer(String modelCfgPath) { - Map proMap = this.readProperties(modelCfgPath); - try { - int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); - this.sampleRate = sampleRate; - EndpointRule rule1 = - new EndpointRule( - false, - Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), - 0.0F); - EndpointRule rule2 = - new EndpointRule( - true, - Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), - 0.0F); - EndpointRule rule3 = - new EndpointRule( - false, - 0.0F, - Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); + // load config file for OnlineRecognizer + public OnlineRecognizer(String modelCfgPath) { + Map proMap = this.readProperties(modelCfgPath); + try { + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); + this.sampleRate = sampleRate; + EndpointRule rule1 = + new EndpointRule( + false, + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), + 0.0F); + EndpointRule rule2 = + new EndpointRule( + true, + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), + 0.0F); + EndpointRule rule3 = + new EndpointRule( + false, + 0.0F, + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); + EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = - new OnlineParaformerModelConfig( - proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig( - proMap.getOrDefault("encoder", "").trim(), - proMap.getOrDefault("decoder", "").trim(), - proMap.getOrDefault("joiner", "").trim()); - OnlineModelConfig modelCfg = - new OnlineModelConfig( - proMap.getOrDefault("tokens", "").trim(), - Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), - false, - proMap.getOrDefault("model_type", "zipformer").trim(), - modelParaCfg, - modelTranCfg); - FeatureConfig featConfig = - new FeatureConfig( - sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); - OnlineLMConfig onlineLmConfig = - new OnlineLMConfig( - proMap.getOrDefault("lm_model", "").trim(), - Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); + OnlineParaformerModelConfig modelParaCfg = + new OnlineParaformerModelConfig( + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); + OnlineTransducerModelConfig modelTranCfg = + new OnlineTransducerModelConfig( + proMap.getOrDefault("encoder", "").trim(), + proMap.getOrDefault("decoder", "").trim(), + proMap.getOrDefault("joiner", "").trim()); + OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); + OnlineModelConfig modelCfg = + new OnlineModelConfig( + proMap.getOrDefault("tokens", "").trim(), + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), + false, + proMap.getOrDefault("model_type", "zipformer").trim(), + modelParaCfg, + modelTranCfg, zipformer2CtcConfig); + FeatureConfig featConfig = + new FeatureConfig( + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); + OnlineLMConfig onlineLmConfig = + new OnlineLMConfig( + proMap.getOrDefault("lm_model", "").trim(), + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), - proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), - Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), - proMap.getOrDefault("hotwords_file", "").trim(), - Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(new Object(), rcgCfg); + OnlineRecognizerConfig rcgCfg = + new OnlineRecognizerConfig( + featConfig, + modelCfg, + endCfg, + onlineLmConfig, + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), + proMap.getOrDefault("hotwords_file", "").trim(), + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); + // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 + this.ptr = createOnlineRecognizer(new Object(), rcgCfg); - } catch (Exception e) { - System.err.println(e); + } catch (Exception e) { + System.err.println(e); + } } - } - // use for android asset_manager ANDROID_API__ >= 9 - public OnlineRecognizer(Object assetManager, String modelCfgPath) { - Map proMap = this.readProperties(modelCfgPath); - try { - int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); - this.sampleRate = sampleRate; - EndpointRule rule1 = - new EndpointRule( - false, - Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), - 0.0F); - EndpointRule rule2 = - new EndpointRule( - true, - Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), - 0.0F); - EndpointRule rule3 = - new EndpointRule( - false, - 0.0F, - Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = - new OnlineParaformerModelConfig( - proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig( - proMap.getOrDefault("encoder", "").trim(), - proMap.getOrDefault("decoder", "").trim(), - proMap.getOrDefault("joiner", "").trim()); + // use for android asset_manager ANDROID_API__ >= 9 + public OnlineRecognizer(Object assetManager, String modelCfgPath) { + Map proMap = this.readProperties(modelCfgPath); + try { + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); + this.sampleRate = sampleRate; + EndpointRule rule1 = + new EndpointRule( + false, + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), + 0.0F); + EndpointRule rule2 = + new EndpointRule( + true, + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), + 0.0F); + EndpointRule rule3 = + new EndpointRule( + false, + 0.0F, + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); + EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); + OnlineParaformerModelConfig modelParaCfg = + new OnlineParaformerModelConfig( + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); + OnlineTransducerModelConfig modelTranCfg = + new OnlineTransducerModelConfig( + proMap.getOrDefault("encoder", "").trim(), + proMap.getOrDefault("decoder", "").trim(), + proMap.getOrDefault("joiner", "").trim()); + OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); - OnlineModelConfig modelCfg = - new OnlineModelConfig( - proMap.getOrDefault("tokens", "").trim(), - Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), - false, - proMap.getOrDefault("model_type", "zipformer").trim(), - modelParaCfg, - modelTranCfg); - FeatureConfig featConfig = - new FeatureConfig( - sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); + OnlineModelConfig modelCfg = + new OnlineModelConfig( + proMap.getOrDefault("tokens", "").trim(), + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), + false, + proMap.getOrDefault("model_type", "zipformer").trim(), + modelParaCfg, + modelTranCfg, zipformer2CtcConfig); + FeatureConfig featConfig = + new FeatureConfig( + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); - OnlineLMConfig onlineLmConfig = - new OnlineLMConfig( - proMap.getOrDefault("lm_model", "").trim(), - Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); + OnlineLMConfig onlineLmConfig = + new OnlineLMConfig( + proMap.getOrDefault("lm_model", "").trim(), + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), - proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), - Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), - proMap.getOrDefault("hotwords_file", "").trim(), - Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(assetManager, rcgCfg); + OnlineRecognizerConfig rcgCfg = + new OnlineRecognizerConfig( + featConfig, + modelCfg, + endCfg, + onlineLmConfig, + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), + proMap.getOrDefault("hotwords_file", "").trim(), + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); + // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 + this.ptr = createOnlineRecognizer(assetManager, rcgCfg); - } catch (Exception e) { - System.err.println(e); + } catch (Exception e) { + System.err.println(e); + } } - } - // set onlineRecognizer by parameter - public OnlineRecognizer( - String tokens, - String encoder, - String decoder, - String joiner, - int numThreads, - int sampleRate, - int featureDim, - boolean enableEndpointDetection, - float rule1MinTrailingSilence, - float rule2MinTrailingSilence, - float rule3MinUtteranceLength, - String decodingMethod, - String lm_model, - float lm_scale, - int maxActivePaths, - String hotwordsFile, - float hotwordsScore, - String modelType) { - this.sampleRate = sampleRate; - EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); - EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); - EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig(encoder, decoder, joiner); - OnlineModelConfig modelCfg = - new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg); - FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); - OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - enableEndpointDetection, - decodingMethod, - maxActivePaths, - hotwordsFile, - hotwordsScore); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(new Object(), rcgCfg); - } - - private Map readProperties(String modelCfgPath) { - // read and parse config file - Properties props = new Properties(); - Map proMap = new HashMap<>(); - try { - File file = new File(modelCfgPath); - if (!file.exists()) { - System.out.println("model cfg file not exists!"); - System.exit(0); - } - InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath)); - props.load(in); - Enumeration en = props.propertyNames(); - while (en.hasMoreElements()) { - String key = (String) en.nextElement(); - String Property = props.getProperty(key); - proMap.put(key, Property); - } - - } catch (Exception e) { - e.printStackTrace(); + // set onlineRecognizer by parameter + public OnlineRecognizer( + String tokens, + String encoder, + String decoder, + String joiner, + int numThreads, + int sampleRate, + int featureDim, + boolean enableEndpointDetection, + float rule1MinTrailingSilence, + float rule2MinTrailingSilence, + float rule3MinUtteranceLength, + String decodingMethod, + String lm_model, + float lm_scale, + int maxActivePaths, + String hotwordsFile, + float hotwordsScore, + String modelType) { + this.sampleRate = sampleRate; + EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); + EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); + EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); + EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); + OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); + OnlineTransducerModelConfig modelTranCfg = + new OnlineTransducerModelConfig(encoder, decoder, joiner); + OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); + OnlineModelConfig modelCfg = + new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig); + FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); + OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); + OnlineRecognizerConfig rcgCfg = + new OnlineRecognizerConfig( + featConfig, + modelCfg, + endCfg, + onlineLmConfig, + enableEndpointDetection, + decodingMethod, + maxActivePaths, + hotwordsFile, + hotwordsScore); + // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 + this.ptr = createOnlineRecognizer(new Object(), rcgCfg); } - return proMap; - } - public void decodeStream(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - // when feeded samples to engine, call DecodeStream to let it process - decodeStream(this.ptr, streamPtr); - } + public static float[] readWavFile(String fileName) { + // read data from the filename + Object[] wavdata = readWave(fileName); + Object data = wavdata[0]; // data[0] is float data, data[1] sample rate - public void decodeStreams(OnlineStream[] ssOjb) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - // decode for multiple streams - long[] ss = new long[ssOjb.length]; - for (int i = 0; i < ssOjb.length; i++) { - ss[i] = ssOjb[i].getPtr(); - if (ss[i] == 0) throw new Exception("null exception for stream ptr"); + float[] floatData = (float[]) data; + + return floatData; } - decodeStreams(this.ptr, ss); - } - public boolean isReady(OnlineStream s) throws Exception { - // whether the engine is ready for decode - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return isReady(this.ptr, streamPtr); - } + // load the libsherpa-onnx-jni.so lib + public static void loadSoLib(String soPath) { + // load libsherpa-onnx-jni.so lib from the path - public String getResult(OnlineStream s) throws Exception { - // get text from the engine - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return getResult(this.ptr, streamPtr); - } + System.out.println("so lib path=" + soPath + "\n"); + System.load(soPath.trim()); + System.out.println("load so lib succeed\n"); + } - public boolean isEndpoint(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return isEndpoint(this.ptr, streamPtr); - } + public static void setSoPath(String soPath) { + OnlineRecognizer.loadSoLib(soPath); + OnlineStream.loadSoLib(soPath); + } - public void reSet(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - reSet(this.ptr, streamPtr); - } + private static native Object[] readWave(String fileName); // static - public OnlineStream createStream() throws Exception { - // create one stream for data to feed in - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = createStream(this.ptr); - OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate); - return stream; - } + private Map readProperties(String modelCfgPath) { + // read and parse config file + Properties props = new Properties(); + Map proMap = new HashMap<>(); + try { + File file = new File(modelCfgPath); + if (!file.exists()) { + System.out.println("model cfg file not exists!"); + System.exit(0); + } + InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath)); + props.load(in); + Enumeration en = props.propertyNames(); + while (en.hasMoreElements()) { + String key = (String) en.nextElement(); + String Property = props.getProperty(key); + proMap.put(key, Property); + } - public static float[] readWavFile(String fileName) { - // read data from the filename - Object[] wavdata = readWave(fileName); - Object data = wavdata[0]; // data[0] is float data, data[1] sample rate + } catch (Exception e) { + e.printStackTrace(); + } + return proMap; + } - float[] floatData = (float[]) data; + public void decodeStream(OnlineStream s) throws Exception { + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = s.getPtr(); + if (streamPtr == 0) throw new Exception("null exception for stream ptr"); + // when feeded samples to engine, call DecodeStream to let it process + decodeStream(this.ptr, streamPtr); + } - return floatData; - } + public void decodeStreams(OnlineStream[] ssOjb) throws Exception { + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + // decode for multiple streams + long[] ss = new long[ssOjb.length]; + for (int i = 0; i < ssOjb.length; i++) { + ss[i] = ssOjb[i].getPtr(); + if (ss[i] == 0) throw new Exception("null exception for stream ptr"); + } + decodeStreams(this.ptr, ss); + } - // load the libsherpa-onnx-jni.so lib - public static void loadSoLib(String soPath) { - // load libsherpa-onnx-jni.so lib from the path + public boolean isReady(OnlineStream s) throws Exception { + // whether the engine is ready for decode + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = s.getPtr(); + if (streamPtr == 0) throw new Exception("null exception for stream ptr"); + return isReady(this.ptr, streamPtr); + } - System.out.println("so lib path=" + soPath + "\n"); - System.load(soPath.trim()); - System.out.println("load so lib succeed\n"); - } + public String getResult(OnlineStream s) throws Exception { + // get text from the engine + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = s.getPtr(); + if (streamPtr == 0) throw new Exception("null exception for stream ptr"); + return getResult(this.ptr, streamPtr); + } - public static void setSoPath(String soPath) { - OnlineRecognizer.loadSoLib(soPath); - OnlineStream.loadSoLib(soPath); - } + public boolean isEndpoint(OnlineStream s) throws Exception { + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = s.getPtr(); + if (streamPtr == 0) throw new Exception("null exception for stream ptr"); + return isEndpoint(this.ptr, streamPtr); + } - protected void finalize() throws Throwable { - release(); - } + public void reSet(OnlineStream s) throws Exception { + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = s.getPtr(); + if (streamPtr == 0) throw new Exception("null exception for stream ptr"); + reSet(this.ptr, streamPtr); + } - // recognizer release, you'd better call it manually if not use anymore - public void release() { - if (this.ptr == 0) return; - deleteOnlineRecognizer(this.ptr); - this.ptr = 0; - } + public OnlineStream createStream() throws Exception { + // create one stream for data to feed in + if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); + long streamPtr = createStream(this.ptr); + OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate); + return stream; + } - // stream release, you'd better call it manually if not use anymore - public void releaseStream(OnlineStream s) { - s.release(); - } + protected void finalize() throws Throwable { + release(); + } - // JNI interface libsherpa-onnx-jni.so + // recognizer release, you'd better call it manually if not use anymore + public void release() { + if (this.ptr == 0) return; + deleteOnlineRecognizer(this.ptr); + this.ptr = 0; + } - private static native Object[] readWave(String fileName); // static + // JNI interface libsherpa-onnx-jni.so - private native String getResult(long ptr, long streamPtr); + // stream release, you'd better call it manually if not use anymore + public void releaseStream(OnlineStream s) { + s.release(); + } - private native void decodeStream(long ptr, long streamPtr); + private native String getResult(long ptr, long streamPtr); - private native void decodeStreams(long ptr, long[] ssPtr); + private native void decodeStream(long ptr, long streamPtr); - private native boolean isReady(long ptr, long streamPtr); + private native void decodeStreams(long ptr, long[] ssPtr); - // first parameter keep for android asset_manager ANDROID_API__ >= 9 - private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config); + private native boolean isReady(long ptr, long streamPtr); - private native long createStream(long ptr); + // first parameter keep for android asset_manager ANDROID_API__ >= 9 + private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config); - private native void deleteOnlineRecognizer(long ptr); + private native long createStream(long ptr); - private native boolean isEndpoint(long ptr, long streamPtr); + private native void deleteOnlineRecognizer(long ptr); - private native void reSet(long ptr, long streamPtr); + private native boolean isEndpoint(long ptr, long streamPtr); + + private native void reSet(long ptr, long streamPtr); } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 0f1cdb81..74f035cb 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -5,62 +5,62 @@ package com.k2fsa.sherpa.onnx; public class OnlineRecognizerConfig { - private final FeatureConfig featConfig; - private final OnlineModelConfig modelConfig; - private final EndpointConfig endpointConfig; - private final OnlineLMConfig lmConfig; - private final boolean enableEndpoint; - private final String decodingMethod; - private final int maxActivePaths; - private final String hotwordsFile; - private final float hotwordsScore; + private final FeatureConfig featConfig; + private final OnlineModelConfig modelConfig; + private final EndpointConfig endpointConfig; + private final OnlineLMConfig lmConfig; + private final boolean enableEndpoint; + private final String decodingMethod; + private final int maxActivePaths; + private final String hotwordsFile; + private final float hotwordsScore; - public OnlineRecognizerConfig( - FeatureConfig featConfig, - OnlineModelConfig modelConfig, - EndpointConfig endpointConfig, - OnlineLMConfig lmConfig, - boolean enableEndpoint, - String decodingMethod, - int maxActivePaths, - String hotwordsFile, - float hotwordsScore) { - this.featConfig = featConfig; - this.modelConfig = modelConfig; - this.endpointConfig = endpointConfig; - this.lmConfig = lmConfig; - this.enableEndpoint = enableEndpoint; - this.decodingMethod = decodingMethod; - this.maxActivePaths = maxActivePaths; - this.hotwordsFile = hotwordsFile; - this.hotwordsScore = hotwordsScore; - } + public OnlineRecognizerConfig( + FeatureConfig featConfig, + OnlineModelConfig modelConfig, + EndpointConfig endpointConfig, + OnlineLMConfig lmConfig, + boolean enableEndpoint, + String decodingMethod, + int maxActivePaths, + String hotwordsFile, + float hotwordsScore) { + this.featConfig = featConfig; + this.modelConfig = modelConfig; + this.endpointConfig = endpointConfig; + this.lmConfig = lmConfig; + this.enableEndpoint = enableEndpoint; + this.decodingMethod = decodingMethod; + this.maxActivePaths = maxActivePaths; + this.hotwordsFile = hotwordsFile; + this.hotwordsScore = hotwordsScore; + } - public OnlineLMConfig getLmConfig() { - return lmConfig; - } + public OnlineLMConfig getLmConfig() { + return lmConfig; + } - public FeatureConfig getFeatConfig() { - return featConfig; - } + public FeatureConfig getFeatConfig() { + return featConfig; + } - public OnlineModelConfig getModelConfig() { - return modelConfig; - } + public OnlineModelConfig getModelConfig() { + return modelConfig; + } - public EndpointConfig getEndpointConfig() { - return endpointConfig; - } + public EndpointConfig getEndpointConfig() { + return endpointConfig; + } - public boolean isEnableEndpoint() { - return enableEndpoint; - } + public boolean isEnableEndpoint() { + return enableEndpoint; + } - public String getDecodingMethod() { - return decodingMethod; - } + public String getDecodingMethod() { + return decodingMethod; + } - public int getMaxActivePaths() { - return maxActivePaths; - } + public int getMaxActivePaths() { + return maxActivePaths; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java index 557b4d8d..42df0101 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java @@ -4,83 +4,81 @@ // Stream is used for feeding data to the asr engine package com.k2fsa.sherpa.onnx; -import java.io.*; -import java.util.*; - public class OnlineStream { - private long ptr = 0; // this is the stream ptr + private long ptr = 0; // this is the stream ptr - private int sampleRate = 16000; - // assign ptr to this stream in construction - public OnlineStream(long ptr, int sampleRate) { - this.ptr = ptr; - this.sampleRate = sampleRate; - } + private int sampleRate = 16000; - public long getPtr() { - return ptr; - } + // assign ptr to this stream in construction + public OnlineStream(long ptr, int sampleRate) { + this.ptr = ptr; + this.sampleRate = sampleRate; + } - public void acceptWaveform(float[] samples) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); + public static void loadSoLib(String soPath) { + // load .so lib from the path + System.load(soPath.trim()); // ("sherpa-onnx-jni-java"); + } - // feed wave data to asr engine - acceptWaveform(this.ptr, this.sampleRate, samples); - } + public long getPtr() { + return ptr; + } - public void inputFinished() { - // add some tail padding - int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate - float tailPaddings[] = new float[padLen]; // default value is 0 - acceptWaveform(this.ptr, this.sampleRate, tailPaddings); + public void acceptWaveform(float[] samples) throws Exception { + if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - // tell the engine all data are feeded - inputFinished(this.ptr); - } + // feed wave data to asr engine + acceptWaveform(this.ptr, this.sampleRate, samples); + } - public static void loadSoLib(String soPath) { - // load .so lib from the path - System.load(soPath.trim()); // ("sherpa-onnx-jni-java"); - } + public void inputFinished() { + // add some tail padding + int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate + float[] tailPaddings = new float[padLen]; // default value is 0 + acceptWaveform(this.ptr, this.sampleRate, tailPaddings); - public void release() { - // stream object must be release after used - if (this.ptr == 0) return; - deleteStream(this.ptr); - this.ptr = 0; - } + // tell the engine all data are feeded + inputFinished(this.ptr); + } - protected void finalize() throws Throwable { - release(); - } + public void release() { + // stream object must be release after used + if (this.ptr == 0) return; + deleteStream(this.ptr); + this.ptr = 0; + } - public boolean isLastFrame() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - return isLastFrame(this.ptr); - } + protected void finalize() throws Throwable { + release(); + } - public void reSet() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - reSet(this.ptr); - } + public boolean isLastFrame() throws Exception { + if (this.ptr == 0) throw new Exception("null exception for stream ptr"); + return isLastFrame(this.ptr); + } - public int featureDim() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - return featureDim(this.ptr); - } + public void reSet() throws Exception { + if (this.ptr == 0) throw new Exception("null exception for stream ptr"); + reSet(this.ptr); + } - // JNI interface libsherpa-onnx-jni.so - private native void acceptWaveform(long ptr, int sampleRate, float[] samples); + public int featureDim() throws Exception { + if (this.ptr == 0) throw new Exception("null exception for stream ptr"); + return featureDim(this.ptr); + } - private native void inputFinished(long ptr); + // JNI interface libsherpa-onnx-jni.so + private native void acceptWaveform(long ptr, int sampleRate, float[] samples); - private native void deleteStream(long ptr); + private native void inputFinished(long ptr); - private native int numFramesReady(long ptr); + private native void deleteStream(long ptr); - private native boolean isLastFrame(long ptr); + private native int numFramesReady(long ptr); - private native void reSet(long ptr); + private native boolean isLastFrame(long ptr); - private native int featureDim(long ptr); + private native void reSet(long ptr); + + private native int featureDim(long ptr); } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java index a5bc5300..6faf5f96 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java @@ -5,25 +5,25 @@ package com.k2fsa.sherpa.onnx; public class OnlineTransducerModelConfig { - private final String encoder; - private final String decoder; - private final String joiner; + private final String encoder; + private final String decoder; + private final String joiner; - public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) { - this.encoder = encoder; - this.decoder = decoder; - this.joiner = joiner; - } + public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) { + this.encoder = encoder; + this.decoder = decoder; + this.joiner = joiner; + } - public String getEncoder() { - return encoder; - } + public String getEncoder() { + return encoder; + } - public String getDecoder() { - return decoder; - } + public String getDecoder() { + return decoder; + } - public String getJoiner() { - return joiner; - } + public String getJoiner() { + return joiner; + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java new file mode 100644 index 00000000..07309b50 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java @@ -0,0 +1,14 @@ +package com.k2fsa.sherpa.onnx; + +public class OnlineZipformer2CtcModelConfig { + private final String model; + + public OnlineZipformer2CtcModelConfig(String model) { + this.model = model; + } + + public String getModel() { + return model; + } + +} diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index e52abc37..a8f0ef4a 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -1522,8 +1522,8 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete( SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( - JNIEnv *env, jobject /*obj*/, - jlong ptr, jboolean recreate, jstring keywords) { + JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate, + jstring keywords) { auto model = reinterpret_cast(ptr); const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); model->Reset(recreate, p_keywords);