diff --git a/java-api-examples/src/websocketsrv/AsrWebsocketServer.java b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java index 879e9a9f..17421e14 100755 --- a/java-api-examples/src/websocketsrv/AsrWebsocketServer.java +++ b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java @@ -218,14 +218,15 @@ public class AsrWebsocketServer extends WebSocketServer { String soPath = args[0]; String cfgPath = args[1]; - + OnlineRecognizer.setSoPath(soPath); - + logger.info("readProperties"); Map cfgMap = AsrWebsocketServer.readProperties(cfgPath); int port = Integer.valueOf(cfgMap.get("port")); int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); + logger.info("initModelWithCfg"); s.initModelWithCfg(cfgMap, cfgPath); logger.info("Server started on port: " + s.getPort()); s.start(); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java new file mode 100644 index 00000000..42e0a99e --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java @@ -0,0 +1,51 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ + +package com.k2fsa.sherpa.onnx; + +public class OnlineModelConfig { + private final OnlineParaformerModelConfig paraformer; + private final OnlineTransducerModelConfig transducer; + private final String tokens; + private final int numThreads; + private final boolean debug; + private final String provider = "cpu"; + private String modelType = ""; + + public OnlineModelConfig( + String tokens, + int numThreads, + boolean debug, + String modelType, + OnlineParaformerModelConfig paraformer, + OnlineTransducerModelConfig transducer) { + + this.tokens = tokens; + this.numThreads = numThreads; + this.debug = debug; + this.modelType = modelType; + this.paraformer = paraformer; + this.transducer = transducer; + } + + public OnlineParaformerModelConfig getParaformer() { + return paraformer; + } + + public OnlineTransducerModelConfig getTransducer() { + return transducer; + } + + public String getTokens() { + return tokens; + } + + public int getNumThreads() { + return numThreads; + } + + public boolean getDebug() { + return debug; + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java new file mode 100644 index 00000000..c7643f6e --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java @@ -0,0 +1,23 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ + +package com.k2fsa.sherpa.onnx; + +public class OnlineParaformerModelConfig { + private final String encoder; + private final String decoder; + + public OnlineParaformerModelConfig(String encoder, String decoder) { + this.encoder = encoder; + this.decoder = decoder; + } + + public String getEncoder() { + return encoder; + } + + public String getDecoder() { + return decoder; + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 2116320f..85103516 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -56,15 +56,21 @@ public class OnlineRecognizer { new EndpointRule( false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineTransducerModelConfig modelCfg = + + OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig("", ""); + OnlineTransducerModelConfig modelTranCfg = new OnlineTransducerModelConfig( proMap.get("encoder").trim(), proMap.get("decoder").trim(), - proMap.get("joiner").trim(), + proMap.get("joiner").trim()); + OnlineModelConfig modelCfg = + new OnlineModelConfig( proMap.get("tokens").trim(), Integer.parseInt(proMap.get("num_threads").trim()), false, - proMap.get("model_type").trim()); + proMap.get("model_type").trim(), + modelParaCfg, + modelTranCfg); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); OnlineLMConfig onlineLmConfig = @@ -104,15 +110,23 @@ public class OnlineRecognizer { new EndpointRule( false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineTransducerModelConfig modelCfg = + OnlineParaformerModelConfig modelParaCfg = + new OnlineParaformerModelConfig( + proMap.get("encoder").trim(), proMap.get("decoder").trim()); + OnlineTransducerModelConfig modelTranCfg = new OnlineTransducerModelConfig( proMap.get("encoder").trim(), proMap.get("decoder").trim(), - proMap.get("joiner").trim(), + proMap.get("joiner").trim()); + + OnlineModelConfig modelCfg = + new OnlineModelConfig( proMap.get("tokens").trim(), Integer.parseInt(proMap.get("num_threads").trim()), false, - proMap.get("model_type").trim()); + proMap.get("model_type").trim(), + modelParaCfg, + modelTranCfg); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); @@ -160,9 +174,11 @@ public class OnlineRecognizer { EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineTransducerModelConfig modelCfg = - new OnlineTransducerModelConfig( - encoder, decoder, joiner, tokens, numThreads, false, modelType); + OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); + OnlineTransducerModelConfig modelTranCfg = + new OnlineTransducerModelConfig(encoder, decoder, joiner); + OnlineModelConfig modelCfg = + new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); OnlineRecognizerConfig rcgCfg = @@ -277,6 +293,7 @@ public class OnlineRecognizer { System.out.println("so lib path=" + soPath + "\n"); System.load(soPath.trim()); + System.out.println("load so lib succeed\n"); } public static void setSoPath(String soPath) { diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 2b96cb83..4462a708 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -6,26 +6,25 @@ package com.k2fsa.sherpa.onnx; public class OnlineRecognizerConfig { private final FeatureConfig featConfig; - private final OnlineTransducerModelConfig modelConfig; + private final OnlineModelConfig modelConfig; private final EndpointConfig endpointConfig; private final OnlineLMConfig lmConfig; private final boolean enableEndpoint; private final String decodingMethod; private final int maxActivePaths; - public OnlineRecognizerConfig( FeatureConfig featConfig, - OnlineTransducerModelConfig modelConfig, + OnlineModelConfig modelConfig, EndpointConfig endpointConfig, - OnlineLMConfig lmConfig, + OnlineLMConfig lmConfig, boolean enableEndpoint, String decodingMethod, int maxActivePaths) { this.featConfig = featConfig; this.modelConfig = modelConfig; this.endpointConfig = endpointConfig; - this.lmConfig = lmConfig; + this.lmConfig = lmConfig; this.enableEndpoint = enableEndpoint; this.decodingMethod = decodingMethod; this.maxActivePaths = maxActivePaths; @@ -39,7 +38,7 @@ public class OnlineRecognizerConfig { return featConfig; } - public OnlineTransducerModelConfig getModelConfig() { + public OnlineModelConfig getModelConfig() { return modelConfig; } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java index 7697bc51..a5bc5300 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java @@ -8,27 +8,11 @@ public class OnlineTransducerModelConfig { private final String encoder; private final String decoder; private final String joiner; - private final String tokens; - private final int numThreads; - private final boolean debug; - private final String provider = "cpu"; - private String modelType = ""; - public OnlineTransducerModelConfig( - String encoder, - String decoder, - String joiner, - String tokens, - int numThreads, - boolean debug, - String modelType) { + public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) { this.encoder = encoder; this.decoder = decoder; this.joiner = joiner; - this.tokens = tokens; - this.numThreads = numThreads; - this.debug = debug; - this.modelType = modelType; } public String getEncoder() { @@ -42,16 +26,4 @@ public class OnlineTransducerModelConfig { public String getJoiner() { return joiner; } - - public String getTokens() { - return tokens; - } - - public int getNumThreads() { - return numThreads; - } - - public boolean getDebug() { - return debug; - } }