Add Java API for hotwords (#319)
* Add Java API * Support WebSocket * Fix Kotlin
This commit is contained in:
@@ -44,38 +44,48 @@ public class OnlineRecognizer {
|
||||
public OnlineRecognizer(String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
|
||||
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim());
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.get("encoder").trim(),
|
||||
proMap.get("decoder").trim(),
|
||||
proMap.get("joiner").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.get("model_type").trim(),
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
@@ -83,9 +93,11 @@ public class OnlineRecognizer {
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// Create a new recognizer; the first parameter is kept for the Android asset manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
|
||||
@@ -98,41 +110,49 @@ public class OnlineRecognizer {
|
||||
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.get("encoder").trim(), proMap.get("decoder").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.get("encoder").trim(),
|
||||
proMap.get("decoder").trim(),
|
||||
proMap.get("joiner").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.get("model_type").trim(),
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
@@ -140,9 +160,11 @@ public class OnlineRecognizer {
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// Create a new recognizer; the first parameter is kept for the Android asset manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
|
||||
|
||||
@@ -168,6 +190,8 @@ public class OnlineRecognizer {
|
||||
String lm_model,
|
||||
float lm_scale,
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore,
|
||||
String modelType) {
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||
@@ -189,7 +213,9 @@ public class OnlineRecognizer {
|
||||
onlineLmConfig,
|
||||
enableEndpointDetection,
|
||||
decodingMethod,
|
||||
maxActivePaths);
|
||||
maxActivePaths,
|
||||
hotwordsFile,
|
||||
hotwordsScore);
|
||||
// Create a new recognizer; the first parameter is kept for the Android asset manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
}
|
||||
@@ -211,7 +237,6 @@ public class OnlineRecognizer {
|
||||
String key = (String) en.nextElement();
|
||||
String Property = props.getProperty(key);
|
||||
proMap.put(key, Property);
|
||||
// System.out.println(key+"="+Property);
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
|
||||
private final boolean enableEndpoint;
|
||||
private final String decodingMethod;
|
||||
private final int maxActivePaths;
|
||||
private final String hotwordsFile;
|
||||
private final float hotwordsScore;
|
||||
|
||||
public OnlineRecognizerConfig(
|
||||
FeatureConfig featConfig,
|
||||
@@ -20,7 +22,9 @@ public class OnlineRecognizerConfig {
|
||||
OnlineLMConfig lmConfig,
|
||||
boolean enableEndpoint,
|
||||
String decodingMethod,
|
||||
int maxActivePaths) {
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore) {
|
||||
this.featConfig = featConfig;
|
||||
this.modelConfig = modelConfig;
|
||||
this.endpointConfig = endpointConfig;
|
||||
@@ -28,6 +32,8 @@ public class OnlineRecognizerConfig {
|
||||
this.enableEndpoint = enableEndpoint;
|
||||
this.decodingMethod = decodingMethod;
|
||||
this.maxActivePaths = maxActivePaths;
|
||||
this.hotwordsFile = hotwordsFile;
|
||||
this.hotwordsScore = hotwordsScore;
|
||||
}
|
||||
|
||||
public OnlineLMConfig getLmConfig() {
|
||||
|
||||
Reference in New Issue
Block a user