Add blank penalty for various language bindings. (#1234)

This commit is contained in:
Fangjun Kuang
2024-08-08 10:43:31 +08:00
committed by GitHub
parent ba4cb6169f
commit 94e256244d
38 changed files with 123 additions and 42 deletions

View File

@@ -105,7 +105,7 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
recognizer_config.blank_penalty = SHERPA_ONNX_OR(config->blank_penalty, 0.0);
recognizer_config.blank_penalty = config->blank_penalty;
recognizer_config.ctc_fst_decoder_config.graph =
SHERPA_ONNX_OR(config->ctc_fst_decoder_config.graph, "");
@@ -429,6 +429,8 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
recognizer_config.blank_penalty = config->blank_penalty;
recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, "");
recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, "");

View File

@@ -142,11 +142,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
/// Bonus score for each token in hotwords.
float hotwords_score;
float blank_penalty;
SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config;
const char *rule_fsts;
const char *rule_fars;
float blank_penalty;
} SherpaOnnxOnlineRecognizerConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
@@ -430,6 +430,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
float hotwords_score;
const char *rule_fsts;
const char *rule_fars;
float blank_penalty;
} SherpaOnnxOfflineRecognizerConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizer

View File

@@ -11,6 +11,7 @@ public class OfflineRecognizerConfig {
private final float hotwordsScore;
private final String ruleFsts;
private final String ruleFars;
private final float blankPenalty;
private OfflineRecognizerConfig(Builder builder) {
this.featConfig = builder.featConfig;
@@ -21,6 +22,7 @@ public class OfflineRecognizerConfig {
this.hotwordsScore = builder.hotwordsScore;
this.ruleFsts = builder.ruleFsts;
this.ruleFars = builder.ruleFars;
this.blankPenalty = builder.blankPenalty;
}
public static Builder builder() {
@@ -40,6 +42,7 @@ public class OfflineRecognizerConfig {
private float hotwordsScore = 1.5f;
private String ruleFsts = "";
private String ruleFars = "";
private float blankPenalty = 0.0f;
public OfflineRecognizerConfig build() {
return new OfflineRecognizerConfig(this);
@@ -84,5 +87,10 @@ public class OfflineRecognizerConfig {
this.ruleFars = ruleFars;
return this;
}
public Builder setBlankPenalty(float blankPenalty) {
this.blankPenalty = blankPenalty;
return this;
}
}
}

View File

@@ -17,6 +17,7 @@ public class OnlineRecognizerConfig {
private final float hotwordsScore;
private final String ruleFsts;
private final String ruleFars;
private final float blankPenalty;
private OnlineRecognizerConfig(Builder builder) {
this.featConfig = builder.featConfig;
@@ -31,6 +32,7 @@ public class OnlineRecognizerConfig {
this.hotwordsScore = builder.hotwordsScore;
this.ruleFsts = builder.ruleFsts;
this.ruleFars = builder.ruleFars;
this.blankPenalty = builder.blankPenalty;
}
public static Builder builder() {
@@ -54,6 +56,7 @@ public class OnlineRecognizerConfig {
private float hotwordsScore = 1.5f;
private String ruleFsts = "";
private String ruleFars = "";
private float blankPenalty = 0.0f;
public OnlineRecognizerConfig build() {
return new OnlineRecognizerConfig(this);
@@ -118,5 +121,10 @@ public class OnlineRecognizerConfig {
this.ruleFars = ruleFars;
return this;
}
public Builder setBlankPenalty(float blankPenalty) {
this.blankPenalty = blankPenalty;
return this;
}
}
}

View File

@@ -46,6 +46,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "blankPenalty", "F");
ans.blank_penalty = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");

View File

@@ -49,6 +49,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "blankPenalty", "F");
ans.blank_penalty = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");

View File

@@ -62,6 +62,7 @@ data class OfflineRecognizerConfig(
var hotwordsScore: Float = 1.5f,
var ruleFsts: String = "",
var ruleFars: String = "",
var blankPenalty: Float = 0.0f,
)
class OfflineRecognizer(

View File

@@ -71,6 +71,7 @@ data class OnlineRecognizerConfig(
var hotwordsScore: Float = 1.5f,
var ruleFsts: String = "",
var ruleFars: String = "",
var blankPenalty: Float = 0.0f,
)
data class OnlineRecognizerResult(