Add Java API for hotwords (#319)

* Add Java API

* Support WebSocket

* Fix Kotlin
This commit is contained in:
Wei Kang
2023-09-18 22:44:29 +08:00
committed by GitHub
parent 4dfc11066a
commit d7eab95439
9 changed files with 117 additions and 51 deletions

View File

@@ -44,38 +44,48 @@ public class OnlineRecognizer {
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim());
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
@@ -83,9 +93,11 @@ public class OnlineRecognizer {
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
@@ -98,41 +110,49 @@ public class OnlineRecognizer {
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.get("encoder").trim(), proMap.get("decoder").trim());
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
@@ -140,9 +160,11 @@ public class OnlineRecognizer {
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
@@ -168,6 +190,8 @@ public class OnlineRecognizer {
String lm_model,
float lm_scale,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
@@ -189,7 +213,9 @@ public class OnlineRecognizer {
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths);
maxActivePaths,
hotwordsFile,
hotwordsScore);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
@@ -211,7 +237,6 @@ public class OnlineRecognizer {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
// System.out.println(key+"="+Property);
}
} catch (Exception e) {

View File

@@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
private final boolean enableEndpoint;
private final String decodingMethod;
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;
public OnlineRecognizerConfig(
FeatureConfig featConfig,
@@ -20,7 +22,9 @@ public class OnlineRecognizerConfig {
OnlineLMConfig lmConfig,
boolean enableEndpoint,
String decodingMethod,
int maxActivePaths) {
int maxActivePaths,
String hotwordsFile,
float hotwordsScore) {
this.featConfig = featConfig;
this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig;
@@ -28,6 +32,8 @@ public class OnlineRecognizerConfig {
this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore;
}
public OnlineLMConfig getLmConfig() {