Support text normalization via rule FST (#407)
This commit is contained in:
@@ -14,30 +14,50 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
public:
|
||||
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
|
||||
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
||||
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations(), model_->Language(), config.model.debug,
|
||||
model_->IsPiper()) {}
|
||||
model_->IsPiper()) {
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
tn_list_.reserve(files.size());
|
||||
for (const auto &f : files) {
|
||||
if (config.model.debug) {
|
||||
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
|
||||
}
|
||||
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
|
||||
: model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
|
||||
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations(), model_->Language(), config.model.debug,
|
||||
model_->IsPiper()) {}
|
||||
model_->IsPiper()) {
|
||||
if (!config.rule_fsts.empty()) {
|
||||
SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
|
||||
GeneratedAudio Generate(const std::string &_text, int64_t sid = 0,
|
||||
float speed = 1.0) const override {
|
||||
int32_t num_speakers = model_->NumSpeakers();
|
||||
if (num_speakers == 0 && sid != 0) {
|
||||
@@ -55,6 +75,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
sid = 0;
|
||||
}
|
||||
|
||||
std::string text = _text;
|
||||
if (config_.model.debug) {
|
||||
SHERPA_ONNX_LOGE("Raw text: %s", text.c_str());
|
||||
}
|
||||
|
||||
if (!tn_list_.empty()) {
|
||||
for (const auto &tn : tn_list_) {
|
||||
text = tn->Normalize(text);
|
||||
if (config_.model.debug) {
|
||||
SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
||||
if (x.empty()) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
@@ -98,7 +132,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineTtsConfig config_;
|
||||
std::unique_ptr<OfflineTtsVitsModel> model_;
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
|
||||
Lexicon lexicon_;
|
||||
};
|
||||
|
||||
|
||||
@@ -6,19 +6,44 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); }
|
||||
void OfflineTtsConfig::Register(ParseOptions *po) {
|
||||
model.Register(po);
|
||||
|
||||
bool OfflineTtsConfig::Validate() const { return model.Validate(); }
|
||||
po->Register("tts-rule-fsts", &rule_fsts,
|
||||
"It not empty, it contains a list of rule FST filenames."
|
||||
"Multiple filenames are separated by a comma and they are "
|
||||
"applied from left to right. An example value: "
|
||||
"rule1.fst,rule2,fst,rule3.fst");
|
||||
}
|
||||
|
||||
bool OfflineTtsConfig::Validate() const {
|
||||
if (!rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(rule_fsts, ",", false, &files);
|
||||
for (const auto &f : files) {
|
||||
if (!FileExists(f)) {
|
||||
SHERPA_ONNX_LOGE("Rule fst %s does not exist. ", f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return model.Validate();
|
||||
}
|
||||
|
||||
std::string OfflineTtsConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineTtsConfig(";
|
||||
os << "model=" << model.ToString() << ")";
|
||||
os << "model=" << model.ToString() << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -21,10 +21,17 @@ namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTtsConfig {
|
||||
OfflineTtsModelConfig model;
|
||||
// If not empty, it contains a list of rule FST filenames.
|
||||
// Filenames are separated by a comma.
|
||||
// Example value: rule1.fst,rule2,fst,rule3.fst
|
||||
//
|
||||
// If there are multiple rules, they are applied from left to right.
|
||||
std::string rule_fsts;
|
||||
|
||||
OfflineTtsConfig() = default;
|
||||
explicit OfflineTtsConfig(const OfflineTtsModelConfig &model)
|
||||
: model(model) {}
|
||||
OfflineTtsConfig(const OfflineTtsModelConfig &model,
|
||||
const std::string &rule_fsts)
|
||||
: model(model), rule_fsts(rule_fsts) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/python/csrc/offline-tts.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
|
||||
|
||||
@@ -28,8 +30,10 @@ static void PybindOfflineTtsConfig(py::module *m) {
|
||||
using PyClass = OfflineTtsConfig;
|
||||
py::class_<PyClass>(*m, "OfflineTtsConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const OfflineTtsModelConfig &>(), py::arg("model"))
|
||||
.def(py::init<const OfflineTtsModelConfig &, const std::string &>(),
|
||||
py::arg("model"), py::arg("rule_fsts") = "")
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user