Inverse text normalization API of streaming ASR for various programming languages (#1022)

This commit is contained in:
Fangjun Kuang
2024-06-18 13:42:17 +08:00
committed by GitHub
parent 349d957da2
commit 6789c909d2
64 changed files with 849 additions and 55 deletions

View File

@@ -110,6 +110,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.ctc_fst_decoder_config.max_active =
SHERPA_ONNX_OR(config->ctc_fst_decoder_config.max_active, 3000);
recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, "");
recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, "");
if (config->model_config.debug) {
SHERPA_ONNX_LOGE("%s\n", recognizer_config.ToString().c_str());
}

View File

@@ -144,6 +144,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
float hotwords_score;
SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config;
const char *rule_fsts;
const char *rule_fars;
} SherpaOnnxOnlineRecognizerConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {

View File

@@ -190,9 +190,10 @@ if(NOT BUILD_SHARED_LIBS AND APPLE)
target_link_libraries(sherpa-onnx-core "-framework Foundation")
endif()
target_link_libraries(sherpa-onnx-core fstfar fst)
# Libraries needed only when TTS support is enabled. fstfar/fst are already
# linked unconditionally right before this block, so they are not repeated here.
if(SHERPA_ONNX_ENABLE_TTS)
  target_link_libraries(sherpa-onnx-core piper_phonemize)
  target_link_libraries(sherpa-onnx-core cppjieba)
endif()

View File

@@ -425,9 +425,6 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
if (!itn_list_.empty()) {
for (const auto &tn : itn_list_) {
text = tn->Normalize(text);
if (config_.model_config.debug) {
SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
}
}
}

View File

@@ -4,6 +4,8 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include <utility>
#if __ANDROID_API__ >= 9
#include <strstream>
@@ -186,9 +188,6 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
if (!itn_list_.empty()) {
for (const auto &tn : itn_list_) {
text = tn->Normalize(text);
if (config_.model_config.debug) {
SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
}
}
}

View File

@@ -1,3 +1,7 @@
## 1.10.0
* Add inverse text normalization
## 1.9.30
* Add TTS

View File

@@ -111,11 +111,13 @@ class OnlineRecognizerConfig {
this.hotwordsFile = '',
this.hotwordsScore = 1.5,
this.ctcFstDecoderConfig = const OnlineCtcFstDecoderConfig(),
this.ruleFsts = '',
this.ruleFars = '',
});
/// Returns a single-line, human-readable description of this configuration,
/// including the [ruleFsts] and [ruleFars] settings.
@override
String toString() {
  return 'OnlineRecognizerConfig(feat: $feat, model: $model, decodingMethod: $decodingMethod, maxActivePaths: $maxActivePaths, enableEndpoint: $enableEndpoint, rule1MinTrailingSilence: $rule1MinTrailingSilence, rule2MinTrailingSilence: $rule2MinTrailingSilence, rule3MinUtteranceLength: $rule3MinUtteranceLength, hotwordsFile: $hotwordsFile, hotwordsScore: $hotwordsScore, ctcFstDecoderConfig: $ctcFstDecoderConfig, ruleFsts: $ruleFsts, ruleFars: $ruleFars)';
}
final FeatureConfig feat;
@@ -137,6 +139,8 @@ class OnlineRecognizerConfig {
final double hotwordsScore;
final OnlineCtcFstDecoderConfig ctcFstDecoderConfig;
final String ruleFsts;
final String ruleFars;
}
class OnlineRecognizerResult {
@@ -201,9 +205,13 @@ class OnlineRecognizer {
c.ref.ctcFstDecoderConfig.graph =
config.ctcFstDecoderConfig.graph.toNativeUtf8();
c.ref.ctcFstDecoderConfig.maxActive = config.ctcFstDecoderConfig.maxActive;
c.ref.ruleFsts = config.ruleFsts.toNativeUtf8();
c.ref.ruleFars = config.ruleFars.toNativeUtf8();
final ptr = SherpaOnnxBindings.createOnlineRecognizer?.call(c) ?? nullptr;
calloc.free(c.ref.ruleFars);
calloc.free(c.ref.ruleFsts);
calloc.free(c.ref.ctcFstDecoderConfig.graph);
calloc.free(c.ref.hotwordsFile);
calloc.free(c.ref.decodingMethod);

View File

@@ -205,6 +205,9 @@ final class SherpaOnnxOnlineRecognizerConfig extends Struct {
external double hotwordsScore;
external SherpaOnnxOnlineCtcFstDecoderConfig ctcFstDecoderConfig;
external Pointer<Utf8> ruleFsts;
external Pointer<Utf8> ruleFars;
}
final class SherpaOnnxSileroVadModelConfig extends Struct {

View File

@@ -15,6 +15,8 @@ public class OnlineRecognizerConfig {
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;
private final String ruleFsts;
private final String ruleFars;
private OnlineRecognizerConfig(Builder builder) {
this.featConfig = builder.featConfig;
@@ -27,6 +29,8 @@ public class OnlineRecognizerConfig {
this.maxActivePaths = builder.maxActivePaths;
this.hotwordsFile = builder.hotwordsFile;
this.hotwordsScore = builder.hotwordsScore;
this.ruleFsts = builder.ruleFsts;
this.ruleFars = builder.ruleFars;
}
public static Builder builder() {
@@ -48,6 +52,8 @@ public class OnlineRecognizerConfig {
private int maxActivePaths = 4;
private String hotwordsFile = "";
private float hotwordsScore = 1.5f;
private String ruleFsts = "";
private String ruleFars = "";
public OnlineRecognizerConfig build() {
return new OnlineRecognizerConfig(this);
@@ -102,5 +108,15 @@ public class OnlineRecognizerConfig {
this.hotwordsScore = hotwordsScore;
return this;
}
/**
 * Stores the given value in this builder's {@code ruleFsts} field, which
 * defaults to an empty string.
 *
 * <p>NOTE(review): presumably the path(s) of rule FST files used for inverse
 * text normalization; confirm the expected format. The JNI layer appears to
 * read this field with {@code GetStringUTFChars}, so passing {@code null} is
 * likely unsafe; verify before allowing it.
 *
 * @param ruleFsts the value to store
 * @return this builder, so calls can be chained
 */
public Builder setRuleFsts(String ruleFsts) {
this.ruleFsts = ruleFsts;
return this;
}
/**
 * Stores the given value in this builder's {@code ruleFars} field, which
 * defaults to an empty string.
 *
 * <p>NOTE(review): presumably the path(s) of FST archive (FAR) files used for
 * inverse text normalization; confirm the expected format. The JNI layer
 * appears to read this field with {@code GetStringUTFChars}, so passing
 * {@code null} is likely unsafe; verify before allowing it.
 *
 * @param ruleFars the value to store
 * @return this builder, so calls can be chained
 */
public Builder setRuleFars(String ruleFars) {
this.ruleFars = ruleFars;
return this;
}
}
}

View File

@@ -37,6 +37,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");

View File

@@ -69,6 +69,8 @@ data class OnlineRecognizerConfig(
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
var ruleFsts: String = "",
var ruleFars: String = "",
)
data class OnlineRecognizerResult(