update java for paraformer (#276)

This commit is contained in:
zhaomingwork
2023-08-16 20:16:51 +08:00
committed by GitHub
parent f709c95c5f
commit 256a8ecb50
6 changed files with 109 additions and 46 deletions

View File

@@ -218,14 +218,15 @@ public class AsrWebsocketServer extends WebSocketServer {
String soPath = args[0]; String soPath = args[0];
String cfgPath = args[1]; String cfgPath = args[1];
OnlineRecognizer.setSoPath(soPath); OnlineRecognizer.setSoPath(soPath);
logger.info("readProperties");
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
int port = Integer.valueOf(cfgMap.get("port")); int port = Integer.valueOf(cfgMap.get("port"));
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
logger.info("initModelWithCfg");
s.initModelWithCfg(cfgMap, cfgPath); s.initModelWithCfg(cfgMap, cfgPath);
logger.info("Server started on port: " + s.getPort()); logger.info("Server started on port: " + s.getPort());
s.start(); s.start();

View File

@@ -0,0 +1,51 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
package com.k2fsa.sherpa.onnx;
public class OnlineModelConfig {
private final OnlineParaformerModelConfig paraformer;
private final OnlineTransducerModelConfig transducer;
private final String tokens;
private final int numThreads;
private final boolean debug;
private final String provider = "cpu";
private String modelType = "";
public OnlineModelConfig(
String tokens,
int numThreads,
boolean debug,
String modelType,
OnlineParaformerModelConfig paraformer,
OnlineTransducerModelConfig transducer) {
this.tokens = tokens;
this.numThreads = numThreads;
this.debug = debug;
this.modelType = modelType;
this.paraformer = paraformer;
this.transducer = transducer;
}
public OnlineParaformerModelConfig getParaformer() {
return paraformer;
}
public OnlineTransducerModelConfig getTransducer() {
return transducer;
}
public String getTokens() {
return tokens;
}
public int getNumThreads() {
return numThreads;
}
public boolean getDebug() {
return debug;
}
}

View File

@@ -0,0 +1,23 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
package com.k2fsa.sherpa.onnx;
public class OnlineParaformerModelConfig {
private final String encoder;
private final String decoder;
public OnlineParaformerModelConfig(String encoder, String decoder) {
this.encoder = encoder;
this.decoder = decoder;
}
public String getEncoder() {
return encoder;
}
public String getDecoder() {
return decoder;
}
}

View File

@@ -56,15 +56,21 @@ public class OnlineRecognizer {
new EndpointRule( new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineTransducerModelConfig modelCfg =
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig("", "");
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig( new OnlineTransducerModelConfig(
proMap.get("encoder").trim(), proMap.get("encoder").trim(),
proMap.get("decoder").trim(), proMap.get("decoder").trim(),
proMap.get("joiner").trim(), proMap.get("joiner").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(), proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()), Integer.parseInt(proMap.get("num_threads").trim()),
false, false,
proMap.get("model_type").trim()); proMap.get("model_type").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig = FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
OnlineLMConfig onlineLmConfig = OnlineLMConfig onlineLmConfig =
@@ -104,15 +110,23 @@ public class OnlineRecognizer {
new EndpointRule( new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineTransducerModelConfig modelCfg = OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.get("encoder").trim(), proMap.get("decoder").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig( new OnlineTransducerModelConfig(
proMap.get("encoder").trim(), proMap.get("encoder").trim(),
proMap.get("decoder").trim(), proMap.get("decoder").trim(),
proMap.get("joiner").trim(), proMap.get("joiner").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(), proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()), Integer.parseInt(proMap.get("num_threads").trim()),
false, false,
proMap.get("model_type").trim()); proMap.get("model_type").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig = FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
@@ -160,9 +174,11 @@ public class OnlineRecognizer {
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineTransducerModelConfig modelCfg = OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
new OnlineTransducerModelConfig( OnlineTransducerModelConfig modelTranCfg =
encoder, decoder, joiner, tokens, numThreads, false, modelType); new OnlineTransducerModelConfig(encoder, decoder, joiner);
OnlineModelConfig modelCfg =
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg);
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
OnlineRecognizerConfig rcgCfg = OnlineRecognizerConfig rcgCfg =
@@ -277,6 +293,7 @@ public class OnlineRecognizer {
System.out.println("so lib path=" + soPath + "\n"); System.out.println("so lib path=" + soPath + "\n");
System.load(soPath.trim()); System.load(soPath.trim());
System.out.println("load so lib succeed\n");
} }
public static void setSoPath(String soPath) { public static void setSoPath(String soPath) {

View File

@@ -6,26 +6,25 @@ package com.k2fsa.sherpa.onnx;
public class OnlineRecognizerConfig { public class OnlineRecognizerConfig {
private final FeatureConfig featConfig; private final FeatureConfig featConfig;
private final OnlineTransducerModelConfig modelConfig; private final OnlineModelConfig modelConfig;
private final EndpointConfig endpointConfig; private final EndpointConfig endpointConfig;
private final OnlineLMConfig lmConfig; private final OnlineLMConfig lmConfig;
private final boolean enableEndpoint; private final boolean enableEndpoint;
private final String decodingMethod; private final String decodingMethod;
private final int maxActivePaths; private final int maxActivePaths;
public OnlineRecognizerConfig( public OnlineRecognizerConfig(
FeatureConfig featConfig, FeatureConfig featConfig,
OnlineTransducerModelConfig modelConfig, OnlineModelConfig modelConfig,
EndpointConfig endpointConfig, EndpointConfig endpointConfig,
OnlineLMConfig lmConfig, OnlineLMConfig lmConfig,
boolean enableEndpoint, boolean enableEndpoint,
String decodingMethod, String decodingMethod,
int maxActivePaths) { int maxActivePaths) {
this.featConfig = featConfig; this.featConfig = featConfig;
this.modelConfig = modelConfig; this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig; this.endpointConfig = endpointConfig;
this.lmConfig = lmConfig; this.lmConfig = lmConfig;
this.enableEndpoint = enableEndpoint; this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod; this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths; this.maxActivePaths = maxActivePaths;
@@ -39,7 +38,7 @@ public class OnlineRecognizerConfig {
return featConfig; return featConfig;
} }
public OnlineTransducerModelConfig getModelConfig() { public OnlineModelConfig getModelConfig() {
return modelConfig; return modelConfig;
} }

View File

@@ -8,27 +8,11 @@ public class OnlineTransducerModelConfig {
private final String encoder; private final String encoder;
private final String decoder; private final String decoder;
private final String joiner; private final String joiner;
private final String tokens;
private final int numThreads;
private final boolean debug;
private final String provider = "cpu";
private String modelType = "";
public OnlineTransducerModelConfig( public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
String encoder,
String decoder,
String joiner,
String tokens,
int numThreads,
boolean debug,
String modelType) {
this.encoder = encoder; this.encoder = encoder;
this.decoder = decoder; this.decoder = decoder;
this.joiner = joiner; this.joiner = joiner;
this.tokens = tokens;
this.numThreads = numThreads;
this.debug = debug;
this.modelType = modelType;
} }
public String getEncoder() { public String getEncoder() {
@@ -42,16 +26,4 @@ public class OnlineTransducerModelConfig {
public String getJoiner() { public String getJoiner() {
return joiner; return joiner;
} }
public String getTokens() {
return tokens;
}
public int getNumThreads() {
return numThreads;
}
public boolean getDebug() {
return debug;
}
} }