update java for paraformer (#276)
This commit is contained in:
@@ -218,14 +218,15 @@ public class AsrWebsocketServer extends WebSocketServer {
|
|||||||
|
|
||||||
String soPath = args[0];
|
String soPath = args[0];
|
||||||
String cfgPath = args[1];
|
String cfgPath = args[1];
|
||||||
|
|
||||||
OnlineRecognizer.setSoPath(soPath);
|
OnlineRecognizer.setSoPath(soPath);
|
||||||
|
logger.info("readProperties");
|
||||||
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
|
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
|
||||||
int port = Integer.valueOf(cfgMap.get("port"));
|
int port = Integer.valueOf(cfgMap.get("port"));
|
||||||
|
|
||||||
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
|
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
|
||||||
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
|
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
|
||||||
|
logger.info("initModelWithCfg");
|
||||||
s.initModelWithCfg(cfgMap, cfgPath);
|
s.initModelWithCfg(cfgMap, cfgPath);
|
||||||
logger.info("Server started on port: " + s.getPort());
|
logger.info("Server started on port: " + s.getPort());
|
||||||
s.start();
|
s.start();
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
public class OnlineModelConfig {
|
||||||
|
private final OnlineParaformerModelConfig paraformer;
|
||||||
|
private final OnlineTransducerModelConfig transducer;
|
||||||
|
private final String tokens;
|
||||||
|
private final int numThreads;
|
||||||
|
private final boolean debug;
|
||||||
|
private final String provider = "cpu";
|
||||||
|
private String modelType = "";
|
||||||
|
|
||||||
|
public OnlineModelConfig(
|
||||||
|
String tokens,
|
||||||
|
int numThreads,
|
||||||
|
boolean debug,
|
||||||
|
String modelType,
|
||||||
|
OnlineParaformerModelConfig paraformer,
|
||||||
|
OnlineTransducerModelConfig transducer) {
|
||||||
|
|
||||||
|
this.tokens = tokens;
|
||||||
|
this.numThreads = numThreads;
|
||||||
|
this.debug = debug;
|
||||||
|
this.modelType = modelType;
|
||||||
|
this.paraformer = paraformer;
|
||||||
|
this.transducer = transducer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnlineParaformerModelConfig getParaformer() {
|
||||||
|
return paraformer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnlineTransducerModelConfig getTransducer() {
|
||||||
|
return transducer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTokens() {
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumThreads() {
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean getDebug() {
|
||||||
|
return debug;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
public class OnlineParaformerModelConfig {
|
||||||
|
private final String encoder;
|
||||||
|
private final String decoder;
|
||||||
|
|
||||||
|
public OnlineParaformerModelConfig(String encoder, String decoder) {
|
||||||
|
this.encoder = encoder;
|
||||||
|
this.decoder = decoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getEncoder() {
|
||||||
|
return encoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getDecoder() {
|
||||||
|
return decoder;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,15 +56,21 @@ public class OnlineRecognizer {
|
|||||||
new EndpointRule(
|
new EndpointRule(
|
||||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||||
OnlineTransducerModelConfig modelCfg =
|
|
||||||
|
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig("", "");
|
||||||
|
OnlineTransducerModelConfig modelTranCfg =
|
||||||
new OnlineTransducerModelConfig(
|
new OnlineTransducerModelConfig(
|
||||||
proMap.get("encoder").trim(),
|
proMap.get("encoder").trim(),
|
||||||
proMap.get("decoder").trim(),
|
proMap.get("decoder").trim(),
|
||||||
proMap.get("joiner").trim(),
|
proMap.get("joiner").trim());
|
||||||
|
OnlineModelConfig modelCfg =
|
||||||
|
new OnlineModelConfig(
|
||||||
proMap.get("tokens").trim(),
|
proMap.get("tokens").trim(),
|
||||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||||
false,
|
false,
|
||||||
proMap.get("model_type").trim());
|
proMap.get("model_type").trim(),
|
||||||
|
modelParaCfg,
|
||||||
|
modelTranCfg);
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
OnlineLMConfig onlineLmConfig =
|
OnlineLMConfig onlineLmConfig =
|
||||||
@@ -104,15 +110,23 @@ public class OnlineRecognizer {
|
|||||||
new EndpointRule(
|
new EndpointRule(
|
||||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||||
OnlineTransducerModelConfig modelCfg =
|
OnlineParaformerModelConfig modelParaCfg =
|
||||||
|
new OnlineParaformerModelConfig(
|
||||||
|
proMap.get("encoder").trim(), proMap.get("decoder").trim());
|
||||||
|
OnlineTransducerModelConfig modelTranCfg =
|
||||||
new OnlineTransducerModelConfig(
|
new OnlineTransducerModelConfig(
|
||||||
proMap.get("encoder").trim(),
|
proMap.get("encoder").trim(),
|
||||||
proMap.get("decoder").trim(),
|
proMap.get("decoder").trim(),
|
||||||
proMap.get("joiner").trim(),
|
proMap.get("joiner").trim());
|
||||||
|
|
||||||
|
OnlineModelConfig modelCfg =
|
||||||
|
new OnlineModelConfig(
|
||||||
proMap.get("tokens").trim(),
|
proMap.get("tokens").trim(),
|
||||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||||
false,
|
false,
|
||||||
proMap.get("model_type").trim());
|
proMap.get("model_type").trim(),
|
||||||
|
modelParaCfg,
|
||||||
|
modelTranCfg);
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
|
|
||||||
@@ -160,9 +174,11 @@ public class OnlineRecognizer {
|
|||||||
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
|
||||||
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
|
||||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||||
OnlineTransducerModelConfig modelCfg =
|
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
|
||||||
new OnlineTransducerModelConfig(
|
OnlineTransducerModelConfig modelTranCfg =
|
||||||
encoder, decoder, joiner, tokens, numThreads, false, modelType);
|
new OnlineTransducerModelConfig(encoder, decoder, joiner);
|
||||||
|
OnlineModelConfig modelCfg =
|
||||||
|
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg);
|
||||||
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
||||||
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
|
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
|
||||||
OnlineRecognizerConfig rcgCfg =
|
OnlineRecognizerConfig rcgCfg =
|
||||||
@@ -277,6 +293,7 @@ public class OnlineRecognizer {
|
|||||||
|
|
||||||
System.out.println("so lib path=" + soPath + "\n");
|
System.out.println("so lib path=" + soPath + "\n");
|
||||||
System.load(soPath.trim());
|
System.load(soPath.trim());
|
||||||
|
System.out.println("load so lib succeed\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void setSoPath(String soPath) {
|
public static void setSoPath(String soPath) {
|
||||||
|
|||||||
@@ -6,26 +6,25 @@ package com.k2fsa.sherpa.onnx;
|
|||||||
|
|
||||||
public class OnlineRecognizerConfig {
|
public class OnlineRecognizerConfig {
|
||||||
private final FeatureConfig featConfig;
|
private final FeatureConfig featConfig;
|
||||||
private final OnlineTransducerModelConfig modelConfig;
|
private final OnlineModelConfig modelConfig;
|
||||||
private final EndpointConfig endpointConfig;
|
private final EndpointConfig endpointConfig;
|
||||||
private final OnlineLMConfig lmConfig;
|
private final OnlineLMConfig lmConfig;
|
||||||
private final boolean enableEndpoint;
|
private final boolean enableEndpoint;
|
||||||
private final String decodingMethod;
|
private final String decodingMethod;
|
||||||
private final int maxActivePaths;
|
private final int maxActivePaths;
|
||||||
|
|
||||||
|
|
||||||
public OnlineRecognizerConfig(
|
public OnlineRecognizerConfig(
|
||||||
FeatureConfig featConfig,
|
FeatureConfig featConfig,
|
||||||
OnlineTransducerModelConfig modelConfig,
|
OnlineModelConfig modelConfig,
|
||||||
EndpointConfig endpointConfig,
|
EndpointConfig endpointConfig,
|
||||||
OnlineLMConfig lmConfig,
|
OnlineLMConfig lmConfig,
|
||||||
boolean enableEndpoint,
|
boolean enableEndpoint,
|
||||||
String decodingMethod,
|
String decodingMethod,
|
||||||
int maxActivePaths) {
|
int maxActivePaths) {
|
||||||
this.featConfig = featConfig;
|
this.featConfig = featConfig;
|
||||||
this.modelConfig = modelConfig;
|
this.modelConfig = modelConfig;
|
||||||
this.endpointConfig = endpointConfig;
|
this.endpointConfig = endpointConfig;
|
||||||
this.lmConfig = lmConfig;
|
this.lmConfig = lmConfig;
|
||||||
this.enableEndpoint = enableEndpoint;
|
this.enableEndpoint = enableEndpoint;
|
||||||
this.decodingMethod = decodingMethod;
|
this.decodingMethod = decodingMethod;
|
||||||
this.maxActivePaths = maxActivePaths;
|
this.maxActivePaths = maxActivePaths;
|
||||||
@@ -39,7 +38,7 @@ public class OnlineRecognizerConfig {
|
|||||||
return featConfig;
|
return featConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OnlineTransducerModelConfig getModelConfig() {
|
public OnlineModelConfig getModelConfig() {
|
||||||
return modelConfig;
|
return modelConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,27 +8,11 @@ public class OnlineTransducerModelConfig {
|
|||||||
private final String encoder;
|
private final String encoder;
|
||||||
private final String decoder;
|
private final String decoder;
|
||||||
private final String joiner;
|
private final String joiner;
|
||||||
private final String tokens;
|
|
||||||
private final int numThreads;
|
|
||||||
private final boolean debug;
|
|
||||||
private final String provider = "cpu";
|
|
||||||
private String modelType = "";
|
|
||||||
|
|
||||||
public OnlineTransducerModelConfig(
|
public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
|
||||||
String encoder,
|
|
||||||
String decoder,
|
|
||||||
String joiner,
|
|
||||||
String tokens,
|
|
||||||
int numThreads,
|
|
||||||
boolean debug,
|
|
||||||
String modelType) {
|
|
||||||
this.encoder = encoder;
|
this.encoder = encoder;
|
||||||
this.decoder = decoder;
|
this.decoder = decoder;
|
||||||
this.joiner = joiner;
|
this.joiner = joiner;
|
||||||
this.tokens = tokens;
|
|
||||||
this.numThreads = numThreads;
|
|
||||||
this.debug = debug;
|
|
||||||
this.modelType = modelType;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getEncoder() {
|
public String getEncoder() {
|
||||||
@@ -42,16 +26,4 @@ public class OnlineTransducerModelConfig {
|
|||||||
public String getJoiner() {
|
public String getJoiner() {
|
||||||
return joiner;
|
return joiner;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getTokens() {
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getNumThreads() {
|
|
||||||
return numThreads;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean getDebug() {
|
|
||||||
return debug;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user