Add Kotlin and Java API for homophone replacer (#2166)
* Add Kotlin API for homophone replacer * Add Java API for homophone replacer
This commit is contained in:
@@ -11,6 +11,7 @@ java_files += WaveWriter.java
|
||||
java_files += EndpointRule.java
|
||||
java_files += EndpointConfig.java
|
||||
java_files += FeatureConfig.java
|
||||
java_files += HomophoneReplacerConfig.java
|
||||
java_files += OnlineLMConfig.java
|
||||
java_files += OnlineParaformerModelConfig.java
|
||||
java_files += OnlineZipformer2CtcModelConfig.java
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright 2025 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
/**
 * Immutable holder for the three string settings of the homophone replacer:
 * {@code dictDir}, {@code lexicon} and {@code ruleFsts}.
 *
 * <p>Instances are created through {@link #builder()}. Every setting defaults to the
 * empty string and is never {@code null}.
 *
 * <p>The native (JNI) config readers look these fields up by their exact names
 * ({@code "dictDir"}, {@code "lexicon"}, {@code "ruleFsts"}) with type
 * {@code java.lang.String}, so the private field names below must not be renamed
 * without updating the native side as well.
 */
public class HomophoneReplacerConfig {
    private final String dictDir;
    private final String lexicon;
    private final String ruleFsts;

    private HomophoneReplacerConfig(Builder builder) {
        this.dictDir = builder.dictDir;
        this.lexicon = builder.lexicon;
        this.ruleFsts = builder.ruleFsts;
    }

    /**
     * Creates a builder whose settings all start out as the empty string.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the configured {@code dictDir} value; never {@code null}
     */
    public String getDictDir() {
        return dictDir;
    }

    /**
     * @return the configured {@code lexicon} value; never {@code null}
     */
    public String getLexicon() {
        return lexicon;
    }

    /**
     * @return the configured {@code ruleFsts} value; never {@code null}
     */
    public String getRuleFsts() {
        return ruleFsts;
    }

    /**
     * Fluent builder for {@link HomophoneReplacerConfig}.
     *
     * <p>Passing {@code null} to any setter is treated the same as passing the empty
     * string. The native readers hand each field directly to
     * {@code GetStringUTFChars} without a null check, so a {@code null} must never
     * reach them.
     */
    public static class Builder {
        private String dictDir = "";
        private String lexicon = "";
        private String ruleFsts = "";

        /**
         * @return a new configuration carrying the values set on this builder
         */
        public HomophoneReplacerConfig build() {
            return new HomophoneReplacerConfig(this);
        }

        /**
         * @param dictDir value for the {@code dictDir} setting; {@code null} is stored as {@code ""}
         * @return this builder
         */
        public Builder setDictDir(String dictDir) {
            this.dictDir = nullToEmpty(dictDir);
            return this;
        }

        /**
         * @param lexicon value for the {@code lexicon} setting; {@code null} is stored as {@code ""}
         * @return this builder
         */
        public Builder setLexicon(String lexicon) {
            this.lexicon = nullToEmpty(lexicon);
            return this;
        }

        /**
         * @param ruleFsts value for the {@code ruleFsts} setting; {@code null} is stored as {@code ""}
         * @return this builder
         */
        public Builder setRuleFsts(String ruleFsts) {
            this.ruleFsts = nullToEmpty(ruleFsts);
            return this;
        }

        // Maps null to the builder's default so the config never exposes a null string.
        private static String nullToEmpty(String value) {
            return value == null ? "" : value;
        }
    }
}
|
||||
@@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx;
|
||||
public class OfflineRecognizerConfig {
|
||||
private final FeatureConfig featConfig;
|
||||
private final OfflineModelConfig modelConfig;
|
||||
private final HomophoneReplacerConfig hr;
|
||||
private final String decodingMethod;
|
||||
private final int maxActivePaths;
|
||||
private final String hotwordsFile;
|
||||
@@ -16,6 +17,7 @@ public class OfflineRecognizerConfig {
|
||||
private OfflineRecognizerConfig(Builder builder) {
|
||||
this.featConfig = builder.featConfig;
|
||||
this.modelConfig = builder.modelConfig;
|
||||
this.hr = builder.hr;
|
||||
this.decodingMethod = builder.decodingMethod;
|
||||
this.maxActivePaths = builder.maxActivePaths;
|
||||
this.hotwordsFile = builder.hotwordsFile;
|
||||
@@ -36,6 +38,7 @@ public class OfflineRecognizerConfig {
|
||||
public static class Builder {
|
||||
private FeatureConfig featConfig = FeatureConfig.builder().build();
|
||||
private OfflineModelConfig modelConfig = OfflineModelConfig.builder().build();
|
||||
private HomophoneReplacerConfig hr = HomophoneReplacerConfig.builder().build();
|
||||
private String decodingMethod = "greedy_search";
|
||||
private int maxActivePaths = 4;
|
||||
private String hotwordsFile = "";
|
||||
@@ -58,6 +61,11 @@ public class OfflineRecognizerConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHr(HomophoneReplacerConfig hr) {
|
||||
this.hr = hr;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDecodingMethod(String decodingMethod) {
|
||||
this.decodingMethod = decodingMethod;
|
||||
return this;
|
||||
|
||||
@@ -10,6 +10,7 @@ public class OnlineRecognizerConfig {
|
||||
|
||||
private final OnlineCtcFstDecoderConfig ctcFstDecoderConfig;
|
||||
private final EndpointConfig endpointConfig;
|
||||
private final HomophoneReplacerConfig hr;
|
||||
private final boolean enableEndpoint;
|
||||
private final String decodingMethod;
|
||||
private final int maxActivePaths;
|
||||
@@ -25,6 +26,7 @@ public class OnlineRecognizerConfig {
|
||||
this.lmConfig = builder.lmConfig;
|
||||
this.ctcFstDecoderConfig = builder.ctcFstDecoderConfig;
|
||||
this.endpointConfig = builder.endpointConfig;
|
||||
this.hr = builder.hr;
|
||||
this.enableEndpoint = builder.enableEndpoint;
|
||||
this.decodingMethod = builder.decodingMethod;
|
||||
this.maxActivePaths = builder.maxActivePaths;
|
||||
@@ -49,6 +51,7 @@ public class OnlineRecognizerConfig {
|
||||
private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build();
|
||||
private OnlineCtcFstDecoderConfig ctcFstDecoderConfig = OnlineCtcFstDecoderConfig.builder().build();
|
||||
private EndpointConfig endpointConfig = EndpointConfig.builder().build();
|
||||
private HomophoneReplacerConfig hr = HomophoneReplacerConfig.builder().build();
|
||||
private boolean enableEndpoint = true;
|
||||
private String decodingMethod = "greedy_search";
|
||||
private int maxActivePaths = 4;
|
||||
@@ -87,6 +90,11 @@ public class OnlineRecognizerConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHr(HomophoneReplacerConfig hr) {
|
||||
this.hr = hr;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEnableEndpoint(boolean enableEndpoint) {
|
||||
this.enableEndpoint = enableEndpoint;
|
||||
return this;
|
||||
|
||||
@@ -284,6 +284,30 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
ans.model_config.telespeech_ctc = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// homophone replacer config
|
||||
fid = env->GetFieldID(cls, "hr",
|
||||
"Lcom/k2fsa/sherpa/onnx/HomophoneReplacerConfig;");
|
||||
jobject hr_config = env->GetObjectField(config, fid);
|
||||
jclass hr_config_cls = env->GetObjectClass(hr_config);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "dictDir", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.dict_dir = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "lexicon", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.lexicon = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "ruleFsts", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.rule_fsts = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
@@ -253,6 +253,30 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
ans.ctc_fst_decoder_config.max_active =
|
||||
env->GetIntField(fst_decoder_config, fid);
|
||||
|
||||
// homophone replacer config
|
||||
fid = env->GetFieldID(cls, "hr",
|
||||
"Lcom/k2fsa/sherpa/onnx/HomophoneReplacerConfig;");
|
||||
jobject hr_config = env->GetObjectField(config, fid);
|
||||
jclass hr_config_cls = env->GetObjectClass(hr_config);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "dictDir", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.dict_dir = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "lexicon", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.lexicon = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(hr_config_cls, "ruleFsts", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(hr_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hr.rule_fsts = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
7
sherpa-onnx/kotlin-api/HomophoneReplacerConfig.kt
Normal file
7
sherpa-onnx/kotlin-api/HomophoneReplacerConfig.kt
Normal file
@@ -0,0 +1,7 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/**
 * Settings for the homophone replacer.
 *
 * Every property defaults to the empty string.
 *
 * NOTE(review): the JNI config readers added in the same change look up String
 * fields named "dictDir", "lexicon" and "ruleFsts" on
 * com.k2fsa.sherpa.onnx.HomophoneReplacerConfig. Presumably that lookup also
 * serves this Kotlin class, so keep the property names in sync with the native
 * code (confirm before renaming).
 *
 * @property dictDir value of the `dictDir` setting
 * @property lexicon value of the `lexicon` setting
 * @property ruleFsts value of the `ruleFsts` setting
 */
data class HomophoneReplacerConfig(
    var dictDir: String = "",
    var lexicon: String = "",
    var ruleFsts: String = "",
)
|
||||
@@ -78,6 +78,7 @@ data class OfflineRecognizerConfig(
|
||||
var featConfig: FeatureConfig = FeatureConfig(),
|
||||
var modelConfig: OfflineModelConfig = OfflineModelConfig(),
|
||||
// var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
|
||||
var hr: HomophoneReplacerConfig = HomophoneReplacerConfig(),
|
||||
var decodingMethod: String = "greedy_search",
|
||||
var maxActivePaths: Int = 4,
|
||||
var hotwordsFile: String = "",
|
||||
|
||||
@@ -57,12 +57,12 @@ data class OnlineCtcFstDecoderConfig(
|
||||
var maxActive: Int = 3000,
|
||||
)
|
||||
|
||||
|
||||
data class OnlineRecognizerConfig(
|
||||
var featConfig: FeatureConfig = FeatureConfig(),
|
||||
var modelConfig: OnlineModelConfig = OnlineModelConfig(),
|
||||
var lmConfig: OnlineLMConfig = OnlineLMConfig(),
|
||||
var ctcFstDecoderConfig: OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
|
||||
var hr: HomophoneReplacerConfig = HomophoneReplacerConfig(),
|
||||
var endpointConfig: EndpointConfig = EndpointConfig(),
|
||||
var enableEndpoint: Boolean = true,
|
||||
var decodingMethod: String = "greedy_search",
|
||||
|
||||
Reference in New Issue
Block a user